# Conditional SeqGAN - 24


In [None]:
# Mount Google Drive at /content/drive; the project directory used below
# lives under /content/drive/MyDrive/ConditionalSeqGAN.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Shell escape: print the status of the GPU attached to this Colab runtime.
! /opt/bin/nvidia-smi

In [None]:
# Switch to the project directory on the mounted Drive (presumably so the
# project's local modules imported below — seed, gen_opts, generator, ... —
# are found on the import path) and list its contents.
%cd "/content/drive/MyDrive/ConditionalSeqGAN"
!ls

In [None]:
import seed

from gen_opts import gen_opts
from gen_dataset import gen_dataset
from gen_dataloader import get_gen_iter

from encoder_decoder import encoder
from encoder_decoder import decoder
from generator import generator

from gen_optimizer import encoder_optim, decoder_optim, encoder_optim_scheduler, decoder_optim_scheduler
from gen_train_epoch import train_gen

import torch

In [None]:
def load_checkpoint(checkpoint_PATH, encoder, decoder, encoder_optim, decoder_optim,
                    map_location=None):
    """Restore encoder/decoder weights and their optimizer states from a checkpoint.

    Args:
        checkpoint_PATH: Path to a file written by ``torch.save`` holding the keys
            ``'encoder_state_dict'``, ``'decoder_state_dict'``,
            ``'encoder_optim_state_dict'`` and ``'decoder_optim_state_dict'``.
            If ``None``, nothing is loaded and the inputs are returned unchanged.
        encoder: Module whose ``load_state_dict`` receives the saved encoder weights.
        decoder: Module whose ``load_state_dict`` receives the saved decoder weights.
        encoder_optim: Optimizer whose ``load_state_dict`` receives the saved
            encoder-optimizer state.
        decoder_optim: Optimizer whose ``load_state_dict`` receives the saved
            decoder-optimizer state.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load a
            checkpoint saved on GPU into a CPU-only runtime. The default ``None``
            keeps ``torch.load``'s default device placement.

    Returns:
        The tuple ``(encoder, decoder, encoder_optim, decoder_optim)``. These are the
        same objects that were passed in; their state is loaded in place.

    Raises:
        KeyError: if the loaded checkpoint dict lacks one of the expected keys.
    """
    if checkpoint_PATH is None:
        # Nothing to restore: keep the freshly-built models/optimizers.
        return encoder, decoder, encoder_optim, decoder_optim

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model_CKPT = torch.load(checkpoint_PATH, map_location=map_location)
    encoder.load_state_dict(model_CKPT['encoder_state_dict'])
    decoder.load_state_dict(model_CKPT['decoder_state_dict'])
    encoder_optim.load_state_dict(model_CKPT['encoder_optim_state_dict'])
    decoder_optim.load_state_dict(model_CKPT['decoder_optim_state_dict'])
    return encoder, decoder, encoder_optim, decoder_optim

In [None]:
# Resume generator pre-training from a saved checkpoint.
# Set checkpoint_path = None to skip loading and keep the freshly-built models.
checkpoint_path = '/content/drive/MyDrive/ConditionalSeqGAN/checkpoints/generator/A Deep Reinforced Generative Adversarial Network for Abstractive Text Summarization_2020-12-04 18:58:07_epoch_1_iter_3800_loss_0.17_step_100.pt'
encoder, decoder, encoder_optim, decoder_optim = load_checkpoint(
    checkpoint_PATH=checkpoint_path,
    encoder=encoder,
    decoder=decoder,
    encoder_optim=encoder_optim,
    decoder_optim=decoder_optim,
)

## Generator Pre-training

In [None]:
# Batch iterator over the generator dataset, sized by gen_opts.
gen_iter = get_gen_iter(
    gen_dataset=gen_dataset,
    batch_size=gen_opts.batch_size,
)

In [None]:
# If Colab terminates the session, restart from this cell: set continue_point
# to where training should resume and re-run the training cell below.
# NOTE(review): presumably a step/iteration count consumed by train_gen, with
# 0 meaning "start from the beginning" -- confirm against gen_train_epoch.
continue_point = 0

In [None]:
# Pre-train the generator (encoder + decoder) using the settings in gen_opts.
train_gen(
    dataset=gen_dataset,
    encoder=encoder,
    decoder=decoder,
    encoder_optim=encoder_optim,
    decoder_optim=decoder_optim,
    encoder_optim_scheduler=encoder_optim_scheduler,
    decoder_optim_scheduler=decoder_optim_scheduler,
    num_epochs=gen_opts.num_epochs,
    gen_iter=gen_iter,
    save_every_step=gen_opts.save_every_step,
    print_every_step=gen_opts.print_every_step,
    continue_point=continue_point,
)

## Discriminator Pre-training

In [None]:
# Re-enter the project directory and list it (presumably so this section can
# also be run on its own after a runtime restart).
%cd "/content/drive/MyDrive/ConditionalSeqGAN"
!ls

In [None]:
from gen_dataset import gen_dataset

from dis_opts import dis_opts
from discriminator import discriminator
from write_dis_dataset import get_training_pairs
from dis_dataloader import get_dis_iter
from dis_train_epoch import train_dis
from dis_optimizer import dis_optim

import torch
import seed

In [None]:
# Derive the discriminator's training pairs from the generator dataset; they
# are wrapped into dis_iter in the next cell.
training_pairs = get_training_pairs(gen_dataset)

In [None]:
# Iterator over the discriminator training pairs, consumed by train_dis below.
# NOTE(review): num_workers=0 presumably forwards to torch.utils.data.DataLoader
# (load batches in the main process) -- confirm in dis_dataloader.
dis_iter = get_dis_iter(training_pairs=training_pairs, num_workers=0)

In [None]:
# Pre-train the discriminator using the schedule in dis_opts.
train_dis(
    # model + optimizer
    discriminator=discriminator,
    dis_optim=dis_optim,
    # data
    dis_iter=dis_iter,
    # schedule
    num_epochs=dis_opts.num_epochs,
    print_every_step=dis_opts.print_every_step,
    save_every_step=dis_opts.save_every_step,
)

## Adversarial Training

In [None]:
# Re-enter the project directory and list it before the adversarial phase.
%cd "/content/drive/MyDrive/ConditionalSeqGAN"
!ls

In [None]:
from GAN_opts import GAN_opts
from adversarial_train_epoch import train_adversarial
import seed
import torch

In [None]:
# Adversarial training has its own batch size, so rebuild the generator
# iterator with GAN_opts. (get_gen_iter and gen_dataset come from earlier
# sections' imports, not from this section's import cell.)
gen_iter = get_gen_iter(
    gen_dataset=gen_dataset,
    batch_size=GAN_opts.batch_size,
    num_workers=2,
)

In [None]:
# Adversarial (GAN) training of the pre-trained generator against the
# pre-trained discriminator.
# NOTE(review): this cell relies on names created by earlier cells --
# gen_dataset, generator, encoder_optim, decoder_optim (generator section),
# discriminator, dis_optim (discriminator section) and gen_iter (previous
# cell). This section's own imports do not provide them, so run the earlier
# sections in the same runtime first.
# NOTE(review): gen_dataset is passed as both `dataset` and `gen_dataset`, and
# num_epochs is hard-coded here, unlike the pre-training cells, which read it
# from their *_opts. Confirm both are intended.
train_adversarial(
    dataset=gen_dataset,
    generator=generator,
    discriminator=discriminator,
    encoder_optim=encoder_optim,
    decoder_optim=decoder_optim,
    dis_optim=dis_optim,
    gen_iter=gen_iter,
    gen_dataset=gen_dataset,
    num_epochs=1,
    print_every_step=GAN_opts.G_print_every_step,
    save_every_step=GAN_opts.G_save_every_step,
    num_rollout=GAN_opts.num_rollout,
)

print("Finished!")