Commit: changed analogy
samaonline committed Jun 7, 2017
1 parent 6fa1636 commit cfc9157
Showing 1 changed file with 66 additions and 80 deletions.
146 changes: 66 additions & 80 deletions analogy_network.ipynb
@@ -5,8 +5,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2017-05-06T07:32:45.807677",
"start_time": "2017-05-06T00:32:45.793659-07:00"
"end_time": "2017-05-07T08:29:50.360225",
"start_time": "2017-05-07T01:29:50.208659-07:00"
},
"collapsed": true
},
@@ -22,7 +22,7 @@
"import os\n",
"import random\n",
"from scipy.misc import imread\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\""
]
},
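The only change in this cell moves training from GPU 0 to GPU 1. A minimal sketch of how the mask behaves (assuming a multi-GPU machine; not part of the commit):

    import os
    # CUDA_VISIBLE_DEVICES hides every GPU not listed; "1" exposes only the
    # second physical GPU, which TensorFlow then addresses as /gpu:0.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    import tensorflow as tf  # set the mask before TF initializes CUDA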
{
@@ -197,9 +197,15 @@
" l2_loss = mean_l2(self.x_t_n, x_t_n_predicted)\n",
" feat_loss = mean_l2(self.C1(self.x_t), self.C1(x_t_n_predicted))\n",
" adv_loss = -tf.reduce_mean(tf.log(discriminator.get_output_tensor(x_t_n_predicted, self.p_t_n)))\n",
" self.loss = l2_loss + feat_loss + 0.05*adv_loss\n",
" self.opt = tf.train.GradientDescentOptimizer(learning_rate=1e-3).minimize(self.loss, var_list=self.train_variables)\n",
" \n",
" self.loss = 100*l2_loss + 100*feat_loss + 0.05 * adv_loss\n",
" self.opt = tf.train.GradientDescentOptimizer(learning_rate=1e-5).minimize(self.loss, var_list=self.train_variables)\n",
" with tf.name_scope('generator'):\n",
" l2_loss_summ = tf.summary.scalar('l2_loss', l2_loss)\n",
" feature_loss_summ = tf.summary.scalar('feature_loss', feat_loss)\n",
" adversarial_loss_summ = tf.summary.scalar('adversarial_loss', adv_loss)\n",
" loss_summ = tf.summary.scalar('loss', self.loss)\n",
" self.summaries = tf.summary.merge([l2_loss_summ, feature_loss_summ, adversarial_loss_summ, loss_summ])\n",
" \n",
" def get_output_tensor(self, p_t_n, p_t, x_t):\n",
" with tf.variable_scope('generator', reuse=self.has_defined_layers):\n",
" p_t_n_latent = self.f_pose(p_t_n)\n",
@@ -295,9 +301,9 @@
" sess.run(W.assign(W_value))\n",
" sess.run(b.assign(b_value))\n",
" \n",
" def fit_batch(self, sess, p_t, p_t_n, x_t, x_t_n):\n",
" _, loss = sess.run((self.opt, self.loss), feed_dict={ self.p_t : p_t, self.p_t_n : p_t_n, self.x_t : x_t, self.x_t_n : x_t_n })\n",
" return loss\n",
" def fit_batch(self,sess, p_t, p_t_n, x_t, x_t_n):\n",
" _, loss, summaries = sess.run((self.opt, self.loss, self.summaries), feed_dict={ self.p_t : p_t, self.p_t_n : p_t_n, self.x_t : x_t, self.x_t_n : x_t_n })\n",
" return loss, summaries\n",
"\n",
"class Discriminator(object):\n",
" def __init__(self):\n",
@@ -316,9 +322,18 @@
" fake_prob = self.get_output_tensor(x_t_n_pred, self.p_t_n)\n",
" real_mismatch_prob = self.get_output_tensor(self.x_t, self.p_t_n)\n",
" \n",
" self.loss = -tf.reduce_mean(tf.log(real_prob) + 0.5 * tf.log(1 - fake_prob) + 0.5 * tf.log(1 - real_mismatch_prob))\n",
" self.opt = tf.train.GradientDescentOptimizer(learning_rate=1e-3).minimize(self.loss, var_list=self.train_variables)\n",
" \n",
" real_loss = -tf.reduce_mean(tf.log(real_prob))\n",
" fake_loss = -tf.reduce_mean(tf.log(1 - fake_prob))\n",
" mismatch_loss = -tf.reduce_mean(tf.log(1 - real_mismatch_prob))\n",
" self.loss = real_loss + 0.5 * fake_loss + 0.5 * mismatch_loss\n",
" self.opt = tf.train.GradientDescentOptimizer(learning_rate=1e-5).minimize(self.loss, var_list=self.train_variables)\n",
" with tf.name_scope('discriminator'):\n",
" real_loss_summ = tf.summary.scalar('real_loss', real_loss)\n",
" fake_loss_summ = tf.summary.scalar('fake_loss', fake_loss)\n",
" mismatch_loss_summ = tf.summary.scalar('mismatch_loss', mismatch_loss)\n",
" loss_summ = tf.summary.scalar('loss', self.loss)\n",
" self.summaries = tf.summary.merge([real_loss_summ, fake_loss_summ, mismatch_loss_summ, loss_summ])\n",
" \n",
" def get_output_tensor(self, x, p):\n",
" with tf.variable_scope('discriminator', reuse=self.has_defined_layers):\n",
" with tf.variable_scope('f_img'):\n",
@@ -351,8 +366,8 @@
" sess.run(b.assign(b_value))\n",
" \n",
" def fit_batch(self, sess, p_t, p_t_n, x_t, x_t_n):\n",
" _, loss = sess.run((self.opt, self.loss), feed_dict={ self.p_t : p_t, self.p_t_n : p_t_n, self.x_t : x_t, self.x_t_n : x_t_n })\n",
" return loss"
" _, loss, summaries = sess.run((self.opt, self.loss, self.summaries), feed_dict={ self.p_t : p_t, self.p_t_n : p_t_n, self.x_t : x_t, self.x_t_n : x_t_n })\n",
" return loss, summaries"
]
},
{
@@ -389,7 +404,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2017-05-06T07:36:57.792490",
@@ -399,12 +414,12 @@
},
"outputs": [],
"source": [
"L = 1"
"L = 13"
]
},
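L sets the depth of the pose input: L = 1 is a single combined heatmap, L = 13 is one channel per joint, and the pose placeholders must be built with matching depth. A sketch of the assumed shape (the placeholder name is illustrative, not from the notebook):

    L = 13  # 13 per-joint heatmaps; L = 1 would be one combined heatmap
    p_t = tf.placeholder(tf.float32, shape=(None, 224, 224, L))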
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"start_time": "2017-05-06T07:36:58.137Z"
@@ -415,6 +430,7 @@
"source": [
"tf.reset_default_graph()\n",
"sess = tf.Session()\n",
"summary_writer = tf.summary.FileWriter('summaries/', graph=sess.graph)\n",
"\n",
"generator = Generator()\n",
"discriminator = Discriminator()\n",
@@ -430,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {
"collapsed": false,
"scrolled": true
@@ -445,38 +461,39 @@
" for frame in os.listdir('squats/' + str(video) + '/'):\n",
" filename = 'squats/' + str(video) + '/' + str(frame)\n",
" videos[video].append(imread(filename))\n",
" \n",
"\n",
"# Create dictionary of heatmat labels for squat videos. \n",
"# For L = 13\n",
"# Key = video number, Value = list of stack of joints (numpy array images (224x224x13))\n",
"if L == 13:\n",
" labels = {}\n",
" for video in os.listdir('squats_labels_multiple/'):\n",
" labels[video] = []\n",
" for frame in os.listdir('squats_labels_multiple/' + str(video) + '/'):\n",
" frame_folder = 'squats_labels_multiple/' + str(video) + '/' + str(frame) + '/'\n",
" temp_image_stack = np.zeros((224,224,13))\n",
" i = 0\n",
" for filename in os.listdir(frame_folder):\n",
" temp_image_stack[:,:,i] = imread(frame_folder + filename)\n",
" i = i + 1\n",
" labels[video].append(temp_image_stack)\n",
" \n",
" if video in videos:\n",
" labels[video] = []\n",
" for frame in os.listdir('squats_labels_multiple/' + str(video) + '/'):\n",
" frame_folder = 'squats_labels_multiple/' + str(video) + '/' + str(frame) + '/'\n",
" temp_image_stack = np.zeros((224,224,13))\n",
" i = 0\n",
" for filename in os.listdir(frame_folder):\n",
" temp_image_stack[:,:,i] = imread(frame_folder + filename)\n",
" i = i + 1\n",
" labels[video].append(temp_image_stack)\n",
"\n",
"# For L = 1 \n",
"# Key = video number, Value = list of heatmaps for each frame (numpy array images (224x224x1)) \n",
"elif L == 1:\n",
" labels = {}\n",
" for video in os.listdir('squats_labels/'):\n",
" labels[video] = []\n",
" for frame in os.listdir('squats_labels/' + str(video) + '/'):\n",
" filename = 'squats_labels/' + str(video) + '/' + str(frame)\n",
" labels[video].append(imread(filename).reshape((224,224,1)))"
" if video in videos:\n",
" labels[video] = []\n",
" for frame in os.listdir('squats_labels/' + str(video) + '/'):\n",
" filename = 'squats_labels/' + str(video) + '/' + str(frame)\n",
" labels[video].append(imread(filename).reshape((224,224,1)))"
]
},
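One caveat in the loading code: os.listdir returns entries in arbitrary filesystem order, so the 13 joint channels are stacked in whatever order the directory reports, and channel k is not guaranteed to hold the same joint across frames. A hedged sketch (the helper name and paths are illustrative) that sorts the listing first:

    import os
    import numpy as np
    from scipy.misc import imread

    def load_joint_stack(frame_folder, n_joints=13):
        # sort so channel k holds the same joint for every frame and video
        files = sorted(os.listdir(frame_folder))[:n_joints]
        return np.stack([imread(os.path.join(frame_folder, f)) for f in files],
                        axis=-1)  # shape (224, 224, n_joints)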
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {
"collapsed": true
},
@@ -502,26 +519,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(224, 224, 3)\n"
]
}
],
"source": [
"print videos['1659'][0].shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
@@ -531,33 +529,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1: gen_loss=nan, disc_loss=nan\n",
"epoch 2: gen_loss=nan, disc_loss=nan\n",
"epoch 3: gen_loss=nan, disc_loss=nan\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-12-25cd9f54ffe3>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_iter\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[0mf1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mf2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mh1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mh2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_minibatch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 14\u001b[1;33m \u001b[0mgen_loss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgenerator\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit_batch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msess\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mh1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mh2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mf1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mf2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 15\u001b[0m \u001b[0mdisc_loss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdiscriminator\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit_batch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msess\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mh1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mh2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mf1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mf2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[0mtotal_gen_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mgen_loss\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m<ipython-input-3-b5462c8d03f0>\u001b[0m in \u001b[0;36mfit_batch\u001b[1;34m(self, sess, p_t, p_t_n, x_t, x_t_n)\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfit_batch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msess\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mp_t\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mp_t_n\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx_t\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx_t_n\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 116\u001b[1;33m \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msess\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mp_t\u001b[0m \u001b[1;33m:\u001b[0m \u001b[0mp_t\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mp_t_n\u001b[0m \u001b[1;33m:\u001b[0m \u001b[0mp_t_n\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mx_t\u001b[0m \u001b[1;33m:\u001b[0m \u001b[0mx_t\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mx_t_n\u001b[0m \u001b[1;33m:\u001b[0m \u001b[0mx_t_n\u001b[0m \u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 117\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 118\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/jeffzhang/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 776\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 777\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 778\u001b[1;33m run_metadata_ptr)\n\u001b[0m\u001b[0;32m 779\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 780\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/jeffzhang/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 980\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 981\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m--> 982\u001b[1;33m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[0;32m 983\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 984\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/jeffzhang/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1030\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1031\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[1;32m-> 1032\u001b[1;33m target_list, options, run_metadata)\n\u001b[0m\u001b[0;32m 1033\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1034\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
"\u001b[1;32m/home/jeffzhang/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m 1037\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1038\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1039\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1040\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1041\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/jeffzhang/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m 1019\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[0;32m 1020\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1021\u001b[1;33m status, run_metadata)\n\u001b[0m\u001b[0;32m 1022\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1023\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
"epoch 1: gen_loss=inf, disc_loss=1.5596\n",
"epoch 2: gen_loss=1358344.4966, disc_loss=1.4132\n",
"epoch 3: gen_loss=1378886.7206, disc_loss=1.3245\n"
]
}
],
"source": [
"epochs = 100\n",
"n_samples = 100\n",
"batch_size = 16\n",
"n_samples = 1000\n",
"batch_size = 2\n",
"summary_freq_iter = 10\n",
"display_step = 1\n",
"\n",
"mean_gen_losses = []\n",
Expand All @@ -568,19 +550,23 @@
" total_disc_loss = 0\n",
" for i in range(total_iter):\n",
" f1,f2,h1,h2 = create_minibatch(batch_size)\n",
" gen_loss = generator.fit_batch(sess,h1,h2,f1,f2)\n",
" disc_loss = discriminator.fit_batch(sess,h1,h2,f1,f2)\n",
" gen_loss, gen_summaries = generator.fit_batch(sess,h1,h2,f1,f2)\n",
" disc_loss, disc_summaries = discriminator.fit_batch(sess,h1,h2,f1,f2)\n",
" total_gen_loss += gen_loss\n",
" total_disc_loss += disc_loss\n",
" if i % summary_freq_iter == 0:\n",
" step = epoch * n_samples + (i + 1) * batch_size\n",
" summary_writer.add_summary(gen_summaries, step)\n",
" summary_writer.add_summary(disc_summaries, step)\n",
" mean_gen_loss = total_gen_loss / total_iter\n",
" mean_disc_loss = total_disc_loss / total_iter\n",
" mean_gen_losses.append(mean_gen_loss)\n",
" mean_disc_losses.append(mean_disc_loss)\n",
" if (epoch + 1) % display_step == 0:\n",
" print('epoch %s: gen_loss=%.4f, disc_loss=%.4f' % (epoch + 1, mean_gen_loss, mean_disc_loss))\n",
"\n",
"# saver = tf.train.Saver()\n",
"# saver.save(sess, 'test_model', global_step = epochs)"
"saver = tf.train.Saver()\n",
"saver.save(sess, '/media/jeffzhang/WD HDD/model/multi-labels-test3-gradient-opt-100-100-005',global_step=100)"
]
},
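Two things worth flagging in this run. The gen_loss=inf in epoch 1 is consistent with tf.log(0): when the discriminator drives a fake probability to exactly zero, the adversarial term diverges. A common guard, sketched here as a suggestion rather than something the commit does (the epsilon value is an assumption), is a small constant inside every log:

    eps = 1e-8  # assumed value, not from the commit
    adv_loss = -tf.reduce_mean(tf.log(
        discriminator.get_output_tensor(x_t_n_predicted, self.p_t_n) + eps))

Second, the summary step epoch * n_samples + (i + 1) * batch_size counts training samples rather than iterations, which keeps TensorBoard x-axes comparable across batch sizes.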
{