In [None]:
!pip install import_ipynb 

In [22]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
cd /content/drive/MyDrive/retrospective cycle GAN

/content/drive/MyDrive/retrospective cycle GAN


In [24]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import numpy as np
import matplotlib.pyplot as plt
import cv2

import glob
import pickle as pkl
import import_ipynb
import os

from dataset_prepration.Dataset import UCF
from models import Generator,Discriminator

In [25]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


# Defining Dataset

In [26]:
train_dataset = UCF(flag = 'train')
train_data_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=1) 

test_dataset = UCF(flag = 'test')
test_data_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=1)

In [27]:
print('train data:',len(train_dataset))
print('test data:',len(test_dataset))

train data: 3844
test data: 962


# Defining Models

In [28]:
generator = Generator(12).to(device)
frame_discriminator  = Discriminator(3).to(device)
sequence_discriminator = Discriminator(15).to(device)

# Setting HyperParameters

In [29]:
lambda1, lambda2, lambda3 = 0.005,0.003,0.003
beta1, beta2 = 0.5,0.999
learning_rate = 0.0003
epochs = 10

# Defining Loss Functions and Optimizers

In [30]:
generator_params = [p for p in generator.parameters() if p.requires_grad]
fdiscriminator_params = [p for p in frame_discriminator.parameters() if p.requires_grad]
sdiscriminator_params = [p for p in sequence_discriminator.parameters() if p.requires_grad]

generator_optimizer = torch.optim.Adam(lr=learning_rate, betas=(beta1, beta2), params=generator_params)
frame_discriminator_oprimizer = torch.optim.Adam(lr=learning_rate, betas=(beta1, beta2), params=fdiscriminator_params)
sequence_discriminator_optimizer = torch.optim.Adam(lr=learning_rate, betas=(beta1, beta2), params=sdiscriminator_params)

In [31]:
scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size=20, gamma=0.1)

In [32]:
discriminator_loss_function = nn.MSELoss()
generator_loss_function = nn.L1Loss()

def discriminator_adversarial_loss(real,fake):
  real = torch.squeeze(real)
  fake = torch.squeeze(fake)
  loss = (discriminator_loss_function(real,torch.ones_like(real)) + discriminator_loss_function(fake,torch.zeros_like(fake))) * 0.5
  return loss

def generator_adversarial_loss(fake):
  loss = discriminator_loss_function(fake, torch.ones_like(fake))
  return loss

def image_similarity(model_output,target):
  model_output = torch.squeeze(model_output)
  target = torch.squeeze(target)
  loss = generator_loss_function(model_output,target)
  return loss

def image_similarity_LOG(model_output,target):
  model_output = LOG(torch.squeeze(model_output))
  target = LOG(torch.squeeze(target))

  model_output.requires_grad=True
  loss = generator_loss_function(model_output,target)
  return loss

def LOG(image):
  image = (image+1) * 127.5
  image = image.detach().cpu().numpy().astype(np.uint8)

  image = cv2.GaussianBlur(image, (3,3), 0)
  image = cv2.Laplacian(image, cv2.CV_64F)
  image = torch.tensor(image).to(device)
  image = (image/127.5) -1

  return image

In [33]:
def reverse(images):
  images = images.detach().cpu()
  num_channels = 3
  num_images = images.shape[1] // 3

  all_images = []
  for i in range(num_images):
    all_images.append(images[:,i*3:(i*3)+3,:,:])
  all_images.reverse()

  reversed_images = None
  for image in all_images:
    if reversed_images is None:
      reversed_images = image
    else:
      reversed_images = np.concatenate((reversed_images,image),axis=1)
  
  reversed_images = torch.tensor(reversed_images).to(device)
  return reversed_images

# Training

In [None]:
loss_frame_discriminator_over_epochs = [] 
loss_sequence_discriminator_over_epochs = [] 
loss_generator_over_epochs = [] 
loss_generator_similarity_over_epochs = []
loss_generator_LOG_similarity_over_epochs = []

for epoch in range(epochs):
  loss_frame_discriminator_epoch = 0
  loss_sequence_discriminator_epoch = 0
  loss_generator_epoch = 0
  loss_generator_similarity_epoch = 0
  loss_generator_LOG_similarity_epoch = 0

  for iter,images in enumerate(train_data_loader):
    
    images = images.to(device)
    xm_to_xn = images[:,0:12,:,:] #input 
    xn_plus_one = images[:,12:15,:,:] #target 
    xn_plus_one_dash = generator(xm_to_xn) #prediction

    xm_plus_one_to_xn_plus_ones = images[:,3:15,:,:] #input 
    xm_plus_one_to_xn_plus_ones = reverse(xm_plus_one_to_xn_plus_ones)
    xm = images[:,0:3,:,:] #target
    xm_dash = generator(xm_plus_one_to_xn_plus_ones) #prediction

    xm_plus_one_to_xn_plus_one_dash = torch.cat((images[:,3:12,:,:],xn_plus_one_dash.detach()),dim=1)
    xm_plus_one_to_xn_plus_one_dash = reverse(xm_plus_one_to_xn_plus_one_dash)
    xm_double_dash = generator(xm_plus_one_to_xn_plus_one_dash)

    xm_dash_to_xn = torch.cat((xm_dash.detach(),images[:,3:12,:,:]),dim=1)
    xn_plus_one_double_dash = generator(xm_dash_to_xn)


    #********************training the frame discriminator********************
    frame_discriminator_oprimizer.zero_grad()

    fake_logits_xn_plus_one_dash = frame_discriminator(xn_plus_one_dash.detach())
    fake_logits_xn_plus_one_double_dash = frame_discriminator(xn_plus_one_double_dash.detach())
    fake_logits_xm_dash = frame_discriminator(xm_dash.detach())
    fake_logits_xm_double_dash = frame_discriminator(xm_double_dash.detach())

    real_logits_xn = frame_discriminator(xn_plus_one.detach())
    fake_logits_xm = frame_discriminator(xm.detach())

    loss_frame_discriminator = discriminator_adversarial_loss(real_logits_xn,fake_logits_xn_plus_one_dash)
    loss_frame_discriminator += discriminator_adversarial_loss(real_logits_xn,fake_logits_xn_plus_one_double_dash)
    loss_frame_discriminator += discriminator_adversarial_loss(fake_logits_xm,fake_logits_xm_dash)
    loss_frame_discriminator += discriminator_adversarial_loss(fake_logits_xm,fake_logits_xm_double_dash)
    loss_frame_discriminator = loss_frame_discriminator * lambda2

    loss_frame_discriminator_epoch += loss_frame_discriminator.item() #just saving it
    loss_frame_discriminator.backward()
    frame_discriminator_oprimizer.step()

    # it is important to detach the fake data or else training will not work. on detaching, we are creating a fresh copy of the tensor so that when backward is called on the other network, the tensores associated with the other network are not effected!
    #********************training the frame discriminator********************

    forward_real = images
    backward_real = reverse(images)
    seq2 = torch.cat((xm_plus_one_to_xn_plus_ones,xm_dash),dim=1) #backward fake
    seq3 = torch.cat((images[:,0:12,:,:],xn_plus_one_dash),dim=1) #full forward fake
    seq4 = torch.cat((xm_plus_one_to_xn_plus_ones,xm_double_dash),dim=1) #backward fake
    seq5 = torch.cat((images[:,0:12,:,:],xn_plus_one_double_dash),dim=1) #full forward fake

    #********************training the sequence discriminator********************
    sequence_discriminator_optimizer.zero_grad()

    real_logits_forward = sequence_discriminator(forward_real)
    real_logits_backward = sequence_discriminator(backward_real)

    fake_logits = sequence_discriminator(seq3.detach())
    loss_sequence_discriminator =  discriminator_adversarial_loss(real_logits_forward,fake_logits)
    fake_logits = sequence_discriminator(seq5.detach())
    loss_sequence_discriminator +=  discriminator_adversarial_loss(real_logits_forward,fake_logits)
    fake_logits = sequence_discriminator(seq2.detach())
    loss_sequence_discriminator +=  discriminator_adversarial_loss(real_logits_backward,fake_logits)
    fake_logits = sequence_discriminator(seq4.detach())
    loss_sequence_discriminator +=  discriminator_adversarial_loss(real_logits_backward,fake_logits)

    loss_sequence_discriminator = loss_sequence_discriminator * lambda3
    loss_sequence_discriminator_epoch += loss_sequence_discriminator.item() #just saving it
    loss_sequence_discriminator.backward()
    sequence_discriminator_optimizer.step()
    #********************training the sequence discriminator********************
    

    #********************training the generator********************
    generator_optimizer.zero_grad()

    #calculating similarity losses
    generator_loss_image_similarity = image_similarity(xm_dash,xm) 
    generator_loss_image_similarity += image_similarity(xm_double_dash,xm)
    generator_loss_image_similarity += image_similarity(xm_double_dash,xm_dash.detach())
    generator_loss_image_similarity += image_similarity(xn_plus_one_dash,xn_plus_one)
    generator_loss_image_similarity += image_similarity(xn_plus_one_double_dash,xn_plus_one)
    generator_loss_image_similarity += image_similarity(xn_plus_one_double_dash,xn_plus_one_dash.detach())

    #calculating similarity losses on LOG images

    generator_loss_LOG_image_similarity = image_similarity_LOG(xm_dash,xm) 
    generator_loss_LOG_image_similarity += image_similarity_LOG(xm_double_dash,xm)
    generator_loss_LOG_image_similarity += image_similarity_LOG(xm_double_dash,xm_dash.detach())
    generator_loss_LOG_image_similarity += image_similarity_LOG(xn_plus_one_dash,xn_plus_one) 
    generator_loss_LOG_image_similarity += image_similarity_LOG(xn_plus_one_double_dash,xn_plus_one) 
    generator_loss_LOG_image_similarity += image_similarity_LOG(xn_plus_one_double_dash,xn_plus_one_dash.detach())

    loss_generator = generator_loss_image_similarity + (lambda1 * generator_loss_LOG_image_similarity)
    loss_generator.backward()
    generator_optimizer.step()

    loss_generator_similarity_epoch += generator_loss_image_similarity.item()/6 #just saving it
    loss_generator_LOG_similarity_epoch += generator_loss_LOG_image_similarity.item()/6 #just saving it
    loss_generator_epoch += loss_generator.item() #just saving it

    #********************training the generator********************
  scheduler.step()
  print(f"finished epoch: {epoch}")

  loss_frame_discriminator_over_epochs.append(loss_frame_discriminator_epoch/len(train_data_loader))
  loss_sequence_discriminator_over_epochs.append(loss_sequence_discriminator_epoch/len(train_data_loader))
  loss_generator_over_epochs.append(loss_generator_epoch/len(train_data_loader))

  loss_generator_similarity_over_epochs.append(loss_generator_similarity_epoch/len(train_data_loader))
  loss_generator_LOG_similarity_over_epochs.append(loss_generator_LOG_similarity_epoch/len(train_data_loader))

  print(f"Frame Discriminator Loss: {loss_frame_discriminator_epoch/len(train_data_loader)} *** Sequence Discriminator Loss: {loss_sequence_discriminator_epoch/len(train_data_loader)}")
  print(f"generation reconstruction loss:{loss_generator_similarity_epoch/len(train_data_loader)} *** generation LOG loss: {loss_generator_LOG_similarity_epoch/len(train_data_loader)} ")
  print(f"Loss generator: {loss_generator_epoch/len(train_data_loader)}")
  print("************************************************************")

  torch.save(generator.state_dict(), os.path.join('/content/drive/MyDrive/retrospective cycle GAN/trained_models/',f"Generator{epoch}.pth"))
  torch.save(frame_discriminator.state_dict(), os.path.join('/content/drive/MyDrive/retrospective cycle GAN/trained_models/',f"frame_discriminator{epoch}.pth"))
  torch.save(sequence_discriminator.state_dict(), os.path.join('/content/drive/MyDrive/retrospective cycle GAN/trained_models/',f"sequence_discriminator{epoch}.pth"))



finished epoch: 0
Frame Discriminator Loss: 0.0018079738438132272 *** Sequence Discriminator Loss: 0.00269393568342252
generation reconstruction loss:0.11462849890667866 *** generation LOG loss: 0.03569231198369961 
Loss generator: 0.6888417627995844
************************************************************


# Plotting results

In [None]:
fig = plt.figure(1, figsize=(15,10))

plt.subplot(2,3,1)
plt.plot(loss_frame_discriminator_over_epochs)
plt.title('Frame Discriminator Loss over Epoches')
                     
plt.subplot(2,3,2)
plt.plot(loss_sequence_discriminator_over_epochs)
plt.title('Sequence Discriminator Loss over Epoches')

plt.subplot(2,3,3)
plt.plot(loss_generator_over_epochs)
plt.title('over all generator Loss over Epoches')

plt.subplot(2,3,4)
plt.plot(loss_generator_similarity_over_epochs)
plt.title('L1-image similarity Loss over Epoches')

plt.subplot(2,3,5)
plt.plot(loss_generator_LOG_similarity_over_epochs)
plt.title('L1-LOG similarity Loss over Epoches')
plt.show()