To use this notebook, just run it from top to bottom

Imports

In [None]:
import torch
from torch import nn
from PIL import Image
import numpy as np
import torch.nn.functional as F
import glob as glob
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
from skimage import color

Model Code

In [None]:
class ContractingBlock(nn.Module):
  '''
  ContractingBlock Class
  Down-sampling stage: two 3x3 convolutions (the first doubles the channel
  count), each optionally followed by batch norm and then LeakyReLU(0.2),
  finished by a 2x2 max-pool with stride 2.
  Parameters:
      input_channels: number of channels of the incoming feature map
      use_bn: whether batch norm is applied after each convolution
  '''
  def __init__(self, input_channels, use_bn=True):
    super().__init__()
    doubled = input_channels * 2
    self.conv1 = nn.Conv2d(input_channels, doubled, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(doubled, doubled, kernel_size=3, padding=1)
    self.activation = nn.LeakyReLU(0.2)
    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
    if use_bn:
      # A single BatchNorm2d instance is applied after both convolutions.
      self.batchnorm = nn.BatchNorm2d(doubled)
    self.use_bn = use_bn

  def forward(self, x):
    '''
    Run conv -> (batch norm) -> activation twice, then max-pool.
    '''
    for conv in (self.conv1, self.conv2):
      x = conv(x)
      if self.use_bn:
        x = self.batchnorm(x)
      x = self.activation(x)
    return self.maxpool(x)

def crop(image, new_shape):
  '''
  Center-crop the last two dimensions of a 4-D tensor.
  Parameters:
      image: tensor indexed as image[:, :, rows, cols]
      new_shape: shape-like whose entries 2 and 3 give the target
                 number of rows and columns
  Returns the centered window of image with the requested size.
  '''
  def _window(full, target):
    # Same truncating arithmetic for both ends keeps the window exactly `target` wide.
    return int((full - target) / 2), int((full + target) / 2)

  top, bottom = _window(image.shape[2], new_shape[2])
  left, right = _window(image.shape[3], new_shape[3])
  return image[:, :, top:bottom, left:right]

class ExpandingBlock(nn.Module):
  '''
  ExpandingBlock Class
  Up-sampling stage: bilinear 2x upsample, a 2x2 convolution that halves
  the channels, concatenation with the (center-cropped) skip connection,
  then a 3x3 and a 2x2 convolution, each optionally followed by batch norm
  and then ReLU.
  Parameters:
      input_channels: number of channels of the incoming feature map
      use_bn: whether batch norm is applied after conv2 and conv3
  '''
  def __init__(self, input_channels, use_bn = True):
    super(ExpandingBlock, self).__init__()
    half_channels = input_channels // 2
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    self.conv1 = nn.Conv2d(input_channels, half_channels, kernel_size=2)
    self.conv2 = nn.Conv2d(input_channels, half_channels, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(half_channels, half_channels, kernel_size=2, padding = 1)
    if use_bn:
      # A single BatchNorm2d instance is applied after both conv2 and conv3.
      self.batchnorm = nn.BatchNorm2d(half_channels)
    # Must be set unconditionally: forward() reads it even when use_bn is False.
    self.use_bn = use_bn
    self.activation = nn.ReLU()

  def forward(self, x, skip_con_x):
    '''
    Parameters:
        x: feature map to upsample
        skip_con_x: feature map from the contracting path; it is
                    center-cropped to x's spatial size before concatenation
    Returns the expanded feature map with input_channels // 2 channels.
    '''
    x = self.upsample(x)
    # kernel_size=2 without padding shrinks each spatial dim by one ...
    x = self.conv1(x)
    skip_con_x = crop(skip_con_x, x.shape)
    x = torch.cat([x, skip_con_x], axis=1)
    x = self.conv2(x)
    if self.use_bn:
      x = self.batchnorm(x)
    x = self.activation(x)
    # ... and kernel_size=2 with padding=1 grows it back by one.
    x = self.conv3(x)
    if self.use_bn:
      x = self.batchnorm(x)
    x = self.activation(x)
    return x

class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    A single 1x1 convolution that remaps the number of channels while
    leaving the spatial dimensions untouched.
    Parameters:
        input_channels: number of channels of the incoming tensor
        output_channels: number of channels to produce
    '''
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        '''
        Apply the 1x1 convolution to x and return the result.
        '''
        return self.conv(x)

class Generator(nn.Module):
  '''
  Generator Class
  U-Net: a 1x1 feature-map block, four contracting blocks, four expanding
  blocks fed with skip connections from the contracting path, a 1x1
  feature-map block down to the output channels, and a final sigmoid.
  Parameters:
      input_channels: number of channels of the conditioning input
      output_channels: number of channels of the generated image
      hidden_channels: number of channels after the first 1x1 convolution
  '''
  def __init__(self, input_channels, output_channels, hidden_channels=64):
    super().__init__()
    self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
    # Each contracting block doubles the channel count.
    self.contract1 = ContractingBlock(hidden_channels)
    self.contract2 = ContractingBlock(hidden_channels * 2)
    self.contract3 = ContractingBlock(hidden_channels * 4)
    self.contract4 = ContractingBlock(hidden_channels * 8)
    # Each expanding block halves it again.
    self.expand1 = ExpandingBlock(hidden_channels * 16)
    self.expand2 = ExpandingBlock(hidden_channels * 8)
    self.expand3 = ExpandingBlock(hidden_channels * 4)
    self.expand4 = ExpandingBlock(hidden_channels * 2)
    self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
    self.sigmoid = torch.nn.Sigmoid()

  def forward(self,x):
    '''
    Run x through the U-Net and return the sigmoid-squashed output.
    '''
    # Stack of activations on the way down; popped in reverse on the way up.
    skips = [self.upfeature(x)]
    for contract in (self.contract1, self.contract2, self.contract3, self.contract4):
      skips.append(contract(skips[-1]))

    out = skips.pop()  # deepest feature map
    for expand in (self.expand1, self.expand2, self.expand3, self.expand4):
      out = expand(out, skips.pop())

    return self.sigmoid(self.downfeature(out))

class Discriminator(nn.Module):
    '''
    Discriminator Class
    Built like the contracting half of the U-Net: a 1x1 feature-map block,
    four contracting blocks (the first without batch norm) and a final 1x1
    convolution down to one channel. The output is a matrix of real/fake
    scores, one per region of the input.
    Parameters:
        input_channels: total number of channels after the image and the
                        condition are concatenated
        hidden_channels: the initial number of discriminator convolutional filters
    '''
    def __init__(self, input_channels, hidden_channels=8):
        super().__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        # Collapse the final feature map to a single score channel.
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        '''
        Parameters:
            x: the image to judge (real or generated)
            y: the condition the image was generated from
        Returns the un-normalised score matrix.
        '''
        features = self.upfeature(torch.cat([x, y], dim=1))
        for contract in (self.contract1, self.contract2, self.contract3, self.contract4):
            features = contract(features)
        return self.final(features)

Mount your Google Drive

In [None]:
# Connect Google Drive
# Mounts the user's Drive at /content/gdrive; every data, frame, chunk and
# checkpoint path used in this notebook lives under /content/gdrive/MyDrive/.
from google.colab import drive
drive.mount('/content/gdrive')
print('Google Drive connected.')

Mounted at /content/gdrive
Google Drive connected.


In [None]:
# All three paths are relative to /content/gdrive/MyDrive/ (the code below prepends that prefix).
#@markdown ## Input File
Input_File = "CS236G_Project/DAIN_test/test.mkv" #@param{type:"string"}

#@markdown ## Frame Location
Frame_Storage = "CS236G_Project/test_frames" #@param{type:"string"}

#@markdown ## Chunk Location
Chunk_Storage = "CS236G_Project/testing_chunks" #@param{type:"string"}

Example code for converting the input video into frames is shown below.

In [None]:
# Dump the input video as numbered PNG frames (%05d.png) rescaled to 704x480.
# NOTE(review): frames are written to the `default` subfolder of Frame_Storage, which
# presumably has to exist already, while the chunking cell further down iterates over
# the subfolders 'a', 'b', 'c1', 'c2' — confirm the intended folder layout.
%shell ffmpeg -i '/content/gdrive/MyDrive/{Input_File}' -vf 'select=gte(n\,1),setpts=PTS-STARTPTS,scale=704:480' '/content/gdrive/MyDrive/{Frame_Storage}/default/%05d.png'

Parameter settings. Note: when `device` is set to 'cuda', the notebook must be run on a GPU runtime.

In [None]:
# Loss functions and their relative weighting
adv_criterion = nn.BCEWithLogitsLoss()  # adversarial loss on the discriminator's raw scores
recon_criterion = nn.L1Loss()           # pixel-wise reconstruction loss
lambda_recon = 200                      # weight of the reconstruction term in the generator loss

# Training schedule and model dimensions
n_epochs = 2
input_dim = 6       # channels of the generator input (two frames stacked channel-wise)
real_dim = 3        # channels of the generated / target frame
display_step = 500  # steps between loss reports, sample images and validation passes
batch_size = 1
lr = 0.0002
target_shape = 256  # only referenced by commented-out resizing code in the training loop

# Prefer the GPU, but fall back to the CPU instead of crashing when none is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print("CUDA is not available; falling back to CPU (training will be very slow).")

Build Generator and Discriminator

In [None]:
# Instantiate both networks first (generator, then discriminator) ...
gen = Generator(input_dim, real_dim).to(device)
# The discriminator sees the image concatenated with its condition, hence the summed channel count.
disc = Discriminator(input_dim + real_dim).to(device)

# ... then give each its own Adam optimizer with the shared learning rate.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

Fill in your checkpoint location below if you want to train from checkpoint

In [None]:
# Your checkpoint location goes here
# The file is expected to hold a dict with "gen", "gen_opt", "disc" and
# "disc_opt" state dicts (those are the keys read by the loading code below).
save_file_location = "/content/gdrive/MyDrive/CS236G_Project/opgan2_14000.pth"
# Set to False to start from freshly initialised weights instead of the checkpoint.
pretrained = True

def weights_init(m):
    '''
    Initialiser meant to be passed to nn.Module.apply.
    Convolution and transposed-convolution weights are drawn from N(0, 0.02);
    batch-norm weights are drawn from N(0, 0.02) and their biases set to 0.
    Any other module type is left untouched.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, 0)

if pretrained:
    # map_location lets a checkpoint saved on one device be restored on whatever
    # `device` is in use now (without it, a CUDA-saved file cannot be loaded
    # on a machine that has no GPU).
    loaded_state = torch.load(save_file_location, map_location=device)
    gen.load_state_dict(loaded_state["gen"])
    gen_opt.load_state_dict(loaded_state["gen_opt"])
    disc.load_state_dict(loaded_state["disc"])
    disc_opt.load_state_dict(loaded_state["disc_opt"])
else:
    # Fresh start: re-initialise conv / batch-norm weights of both networks.
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)

Define Generator Loss function

In [None]:
def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):
    '''
    Compute the generator's training loss for one batch.
    Parameters:
        gen: the generator; called as gen(condition) to produce candidate images
        disc: the discriminator; called as disc(image, condition) and returns
          a matrix of real/fake predictions
        real: the ground-truth images the generated ones are compared against
        condition: the input the generator is conditioned on
        adv_criterion: adversarial loss; called with the discriminator
                  predictions and a same-shaped tensor of target labels
        recon_criterion: reconstruction loss; called with the generated
                    images and the real images
        lambda_recon: weight applied to the reconstruction loss in the sum
    Returns:
        adversarial loss (against all-ones labels) + lambda_recon * reconstruction loss
    '''
    generated = gen(condition)
    disc_pred = disc(generated, condition)
    # The generator is rewarded when the discriminator labels its output as real (ones).
    adversarial_term = adv_criterion(disc_pred, torch.ones_like(disc_pred))
    reconstruction_term = recon_criterion(generated, real)
    return adversarial_term + lambda_recon * reconstruction_term

Define function that converts frames to chunks and saves them

In [None]:
def image2chunk(folder,image_index, frame_gap = 1):
  '''
  Build one training chunk from three frames and save it as a PNG.
  The frames image_index, image_index + frame_gap and
  image_index + 2*frame_gap are read from
  /content/gdrive/MyDrive/<Frame_Storage>/<folder>/ and concatenated
  side by side in the order [first | last | middle]. The result is written
  to /content/gdrive/MyDrive/<Chunk_Storage>/<folder>/ under the zero-padded
  name of the first frame.
  Parameters:
      folder: subfolder name shared by the frame and chunk directories
      image_index: number of the first frame (file names are zero-padded to 5 digits)
      frame_gap: distance between consecutive frames of the triple
  '''
  path_template = '/content/gdrive/MyDrive/{}/{}/{}.png'

  def _load_frame(index):
    # The context manager closes the file as soon as the pixels are copied out.
    with Image.open(path_template.format(Frame_Storage, folder, f'{index:05}')) as frame:
      return np.array(frame)

  a = _load_frame(image_index)
  b = _load_frame(image_index + frame_gap)
  c = _load_frame(image_index + 2*frame_gap)
  # Concatenate the pixel data directly: scaling to [0, 1] and back before a
  # truncating uint8 cast adds work and can lose a level to rounding.
  chunk = np.concatenate((a,c,b), axis = 1).astype(np.uint8, copy=False)
  chunk_im = Image.fromarray(chunk)
  chunk_im.save(path_template.format(Chunk_Storage, folder, f'{image_index:05}'))

This code box will search through the listed subfolders of the folder you specify, convert the frames into chunks, then save the chunks in the corresponding subfolder in the Chunk Storage folder you specified previously.

In [None]:
frame_gap = 1
folders = ['a', 'b', 'c1', 'c2']

for folder in folders:
  # Number of extracted frames available in this subfolder.
  frame_count = len(glob.glob("/content/gdrive/MyDrive/{}/{}/*.png".format(Frame_Storage, folder)))
  # The last 2*frame_gap frames cannot start a triple, so they are skipped.
  todo = frame_count - 2*frame_gap
  print("Commencing on frame folder {}".format(folder))
  for i in range(1, todo+1):  # frame files are numbered from 1
    image2chunk(folder, i, frame_gap)
    if not i % 50:
      print("Processed {} out of {} on todo list".format(i, todo))
  print("Finished on frame folder {}".format(folder))

Visualization function


In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Plot a batch of images as a grid.
    Parameters:
        image_tensor: tensor holding the images; it is viewed as (-1, *size)
        num_images: maximum number of images to include in the grid
        size: per-image shape, passed to view() after the batch dimension
    The grid has 5 images per row and is shown with matplotlib.
    '''
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # Move the channel axis last, which is the layout imshow expects.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

Define some needed transforms for later

In [None]:
# Only conversion to a tensor is applied when images are loaded; no resizing or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

Create the training dataset

In [None]:
# Replace Chunk_Storage with wherever you saved the actual training chunks
training_chunk_storage = Chunk_Storage

dataset = torchvision.datasets.ImageFolder('/content/gdrive/MyDrive/{}'.format(training_chunk_storage), transform=transform)
dataset_len = len(dataset)

# Hold out one fifth of the chunks for validation. The split is seeded so that
# re-running the notebook (e.g. to resume from a checkpoint) keeps the same
# validation set instead of mixing previously trained-on chunks into it.
val_len = dataset_len // 5
split_seed = 0
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [dataset_len - val_len, val_len],
    generator=torch.Generator().manual_seed(split_seed),
)

Create the testing dataset

In [None]:
# Test chunks are read from Chunk_Storage, one class per subfolder as ImageFolder requires.
test_dataset = torchvision.datasets.ImageFolder(
    root='/content/gdrive/MyDrive/{}'.format(Chunk_Storage),
    transform=transform,
)

Define the training function

In [None]:
# Loss histories filled in by train(): one entry per reporting interval.
train_gen, train_disc = [], []
val_gen, val_disc = [], []

# Number of steps between checkpoints. Adjust as needed; note that each
# checkpoint takes roughly 300 MB of space.
save_rate = 2000

# Folder (relative to the Drive root) where checkpoints are written; change freely.
save_location = "CS236G_Project"

def _split_chunk(image):
    '''Split a batch of chunk images into (condition, real) along the width axis.'''
    image_width = image.shape[3]
    # Chunks are three frames side by side: [first | last | middle].
    pre = image[:, :, :, :image_width // 3]
    post = image[:, :, :, image_width // 3:2*image_width // 3]
    # The two outer frames are stacked channel-wise to form the generator input.
    condition = torch.cat((pre, post), dim=1)
    real = image[:, :, :, 2*image_width // 3:]
    return condition, real


def _discriminator_loss(condition, real, fake):
    '''Average of the discriminator's adversarial loss on a fake and a real batch.'''
    disc_fake_hat = disc(fake.detach(), condition) # Detach generator
    disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
    disc_real_hat = disc(real, condition)
    disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
    return (disc_fake_loss + disc_real_loss) / 2


def _run_validation(val_dataloader, val_count_limit):
    '''Return (mean generator loss, mean discriminator loss) over at most val_count_limit batches.'''
    val_count = 0
    gen_loss_total = 0.0
    disc_loss_total = 0.0
    # No parameters are updated here, so skip building autograd graphs.
    # The networks are deliberately left in whatever mode they are in.
    with torch.no_grad():
        for image, _ in tqdm(val_dataloader):
            condition, real = _split_chunk(image)
            condition = condition.to(device)
            real = real.to(device)

            fake = gen(condition)
            disc_loss_total += _discriminator_loss(condition, real, fake).item()
            gen_loss_total += get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon).item()

            val_count += 1
            if val_count >= val_count_limit:
                break
    if val_count == 0:
        return float('nan'), float('nan')
    # Divide by the number of batches actually seen, which may be below the limit.
    return gen_loss_total / val_count, disc_loss_total / val_count


def train(save_model=False):
    '''
    Train the global gen / disc pair on train_dataset for n_epochs epochs.
    Every display_step steps the running mean losses are printed and appended
    to train_gen / train_disc, a real and a generated sample are shown, and a
    validation pass over at most 100 batches of val_dataset appends to
    val_gen / val_disc.
    Parameters:
        save_model: when True, a checkpoint holding the state dicts of both
                    networks and both optimizers is written every save_rate
                    steps to /content/gdrive/MyDrive/<save_location>/opgan2_<step>.pth
    '''
    mean_generator_loss = 0
    mean_discriminator_loss = 0
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    # The step counter restarts at 0 on every call, so checkpoint names repeat across calls.
    cur_step = 0
    val_count_limit = 100

    for epoch in range(n_epochs):
        # Dataloader returns the batches
        for image, _ in tqdm(train_dataloader):
            condition, real = _split_chunk(image)
            condition = condition.to(device)
            real = real.to(device)

            ### Update discriminator ###
            disc_opt.zero_grad() # Zero out the gradient before backpropagation
            with torch.no_grad():
                fake = gen(condition)
            disc_loss = _discriminator_loss(condition, real, fake)
            # fake was produced without a graph, so nothing needs to be retained.
            disc_loss.backward() # Update gradients
            disc_opt.step() # Update optimizer

            ### Update generator ###
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)
            gen_loss.backward() # Update gradients
            gen_opt.step() # Update optimizer

            # Keep track of the average discriminator loss
            mean_discriminator_loss += disc_loss.item() / display_step
            # Keep track of the average generator loss
            mean_generator_loss += gen_loss.item() / display_step

            ### Visualization code ###
            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(f"Epoch {epoch}: Step {cur_step}: Generator (U-Net) loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
                    train_gen.append(mean_generator_loss)
                    train_disc.append(mean_discriminator_loss)
                else:
                    print("Pretrained initial state")
                # Take the per-image shape from the tensors so any frame size can be displayed.
                show_tensor_images(real, size=tuple(real.shape[1:]))
                show_tensor_images(fake, size=tuple(fake.shape[1:]))

                mean_generator_loss = 0
                mean_discriminator_loss = 0

                val_mean_gen_loss, val_mean_disc_loss = _run_validation(val_dataloader, val_count_limit)
                print("Validation Set Gen Loss: {}, Validation Set Disc Loss: {}".format(val_mean_gen_loss, val_mean_disc_loss))
                val_gen.append(val_mean_gen_loss)
                val_disc.append(val_mean_disc_loss)

            # You can change save_model to True if you'd like to save the model
            if save_model and cur_step % save_rate == 0:
                torch.save({'gen': gen.state_dict(),
                    'gen_opt': gen_opt.state_dict(),
                    'disc': disc.state_dict(),
                    'disc_opt': disc_opt.state_dict()
                }, f"/content/gdrive/MyDrive/{save_location}/opgan2_{cur_step}.pth")
            cur_step += 1

Define the testing function

In [None]:
def test_generator(test_count_limit = 300):
  '''
  Evaluate the global gen / disc pair on test_dataset and print the mean
  generator and discriminator losses.
  Every 10th batch, the real frame and the generated frame are displayed.
  Parameters:
      test_count_limit: maximum number of batches to evaluate; the means are
                        taken over the number of batches actually evaluated
  '''
  test_count = 0
  gen_loss_total = 0.0
  disc_loss_total = 0.0
  test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
  # Pure evaluation: no parameters are updated, so no autograd graphs are needed.
  with torch.no_grad():
    for image, _ in tqdm(test_dataloader):
      # Chunks are three frames side by side: [first | last | middle].
      image_width = image.shape[3]
      pre = image[:, :, :, :image_width // 3]
      post = image[:, :, :, image_width // 3:2*image_width // 3]
      condition = torch.cat((pre, post), dim=1)
      real = image[:, :, :, 2*image_width // 3:]

      condition = condition.to(device)
      real = real.to(device)

      fake = gen(condition)
      disc_fake_hat = disc(fake, condition)
      disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
      disc_real_hat = disc(real, condition)
      disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
      disc_loss = (disc_fake_loss + disc_real_loss) / 2

      gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)

      gen_loss_total += gen_loss.item()
      disc_loss_total += disc_loss.item()

      if test_count % 10 == 0:
        # Take the per-image shape from the tensors so any frame size can be displayed.
        show_tensor_images(real, size=tuple(real.shape[1:]))
        show_tensor_images(fake, size=tuple(fake.shape[1:]))

      test_count += 1
      if test_count >= test_count_limit:
        break

  if test_count == 0:
    print("Test set is empty; no losses to report.")
    return
  # Divide by the number of batches actually seen, which may be below the limit.
  test_mean_gen_loss = gen_loss_total / test_count
  test_mean_disc_loss = disc_loss_total / test_count
  print("Test Set Gen Loss: {}, Test Set Disc Loss: {}".format(test_mean_gen_loss, test_mean_disc_loss))

Run the training function

In [None]:
# Runs the full training loop; with save_model=True a checkpoint is written
# every save_rate steps into save_location on the mounted Drive.
train(save_model=True)

Plot Generator loss

In [None]:
# Column 0: training step, column 1: loss.
# Training means are recorded at display_step, 2*display_step, ...;
# validation means are additionally recorded at step 0.
train_gen_spaced = np.column_stack((np.arange(1, len(train_gen) + 1) * display_step, train_gen))
val_gen_spaced = np.column_stack((np.arange(len(val_gen)) * display_step, val_gen))

plt.plot(train_gen_spaced[:,0], train_gen_spaced[:,1], label="train")
plt.plot(val_gen_spaced[:,0], val_gen_spaced[:,1], label="validation")
plt.xlabel("training step")
plt.ylabel("generator loss")
plt.legend()
plt.show()

Plot Discriminator loss

In [None]:
# Column 0: training step, column 1: loss.
# Training means are recorded at display_step, 2*display_step, ...;
# validation means are additionally recorded at step 0.
train_disc_spaced = np.column_stack((np.arange(1, len(train_disc) + 1) * display_step, train_disc))
val_disc_spaced = np.column_stack((np.arange(len(val_disc)) * display_step, val_disc))

plt.plot(train_disc_spaced[:,0], train_disc_spaced[:,1], label="train")
plt.plot(val_disc_spaced[:,0], val_disc_spaced[:,1], label="validation")
plt.xlabel("training step")
plt.ylabel("discriminator loss")
plt.legend()
plt.show()

Run the testing function

In [None]:
# Evaluates up to the default 300 test batches and displays every 10th real/generated pair.
test_generator()

Experimental: Load in a checkpoint and use it to do sample interpolation

In [None]:
# map_location restores the checkpoint onto the device currently in use,
# regardless of the device it was saved from.
loaded_state = torch.load("/content/gdrive/MyDrive/CS236G_Project/opgan_1000.pth", map_location=device)
# Only the generator weights are needed for interpolation.
gen.load_state_dict(loaded_state["gen"])

In [None]:
def synthesizeframe(frame1, frame2):
  '''
  Generate the frame between frame1 and frame2 with the global generator and
  display, in order, frame1, the generated frame and frame2.
  Parameters:
      frame1: first frame as a (height, width, channels) array
      frame2: second frame, same layout as frame1
  NOTE(review): the generator applies a sigmoid, so values are presumably
  expected in [0, 1] as produced by demo_interpolate — confirm for other callers.
  '''
  def _to_batch(frame):
    # (H, W, C) array -> float32 tensor of shape (1, C, H, W)
    chw = np.transpose(np.expand_dims(frame, axis = 0), (0, 3, 1, 2))
    return torch.tensor(chw, dtype=torch.float32)

  ba = _to_batch(frame1)
  aa = _to_batch(frame2)
  # The two frames are stacked channel-wise, as in training.
  test_input = torch.cat((ba, aa), dim=1).to(device)
  # Inference only: avoid building an autograd graph for the full-size frame.
  with torch.no_grad():
    output = gen(test_input)
  # Take the per-image shape from the tensors so any frame size can be displayed.
  show_tensor_images(ba, size=tuple(ba.shape[1:]))
  show_tensor_images(output, size=tuple(output.shape[1:]))
  show_tensor_images(aa, size=tuple(aa.shape[1:]))

In [None]:
def demo_interpolate(folder, image_index, frame_gap = 2,saveframe=False):
  '''
  Load two frames from /content/gdrive/MyDrive/<Frame_Storage>/<folder>/ and
  show the frame the generator synthesises between them.
  Parameters:
      folder: subfolder of Frame_Storage holding the frames
      image_index: number of the first frame (file names are zero-padded to 5 digits)
      frame_gap: offset from the first frame to the second one
      saveframe: accepted but not used by this function
  '''
  frames = []
  for index in (image_index, image_index + frame_gap):
    frame_path = '/content/gdrive/MyDrive/{}/{}/{}.png'.format(Frame_Storage, folder, f'{index:05}')
    # Scale pixel values to [0, 1].
    frames.append(np.array(Image.open(frame_path))/255)
  synthesizeframe(frames[0], frames[1])

In [None]:
# Interpolates between frames 00236.png and 00237.png of subfolder 'c2' in Frame_Storage.
demo_interpolate('c2', 236, frame_gap=1)