In [None]:
# Mount Google Drive so the project files stored there are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Copy the project's helper package from Drive next to the notebook so `from modules...` imports resolve
# (assumes the kernel's working directory is /content, the Colab default).
# NOTE(review): if /content/modules already exists (e.g. on a re-run), `cp -r` copies into
# /content/modules/modules instead of refreshing it — consider removing the target first.
!cp -r /content/drive/MyDrive/AI/TextToAnimation/modules /content/modules

In [None]:
# Import

import torch
from torch import nn
import torch.nn.functional as F

import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import numpy as np

import os
from PIL import Image

# Print the installed PyTorch version so environment differences are visible in the output.
print(torch.__version__)

In [None]:
import sys

# Report the Python interpreter version of the running kernel.
print(sys.version)

In [None]:
from modules.TextProcessing import get_description_from, get_text_parameters
from modules.DataLoading import LabeledAnimationLoader, create_loaders
from modules.PlotAnimation import plot_sample, plot_samples, plot_edge_frames

In [None]:
import arguments

# Constants
## Data parameters — read from the project-level `arguments` module.
(
    dataset_root,
    parameters_save_root,
    train_size,
    validation_size,
    ignore_size,
    batch_size,
) = (
    arguments.dataset_root,
    arguments.parameters_save_root,
    arguments.train_size,
    arguments.validation_size,
    arguments.ignore_size,
    arguments.batch_size,
)

image_channels = 3
# NOTE(review): not referenced elsewhere in this notebook; presumably matches the data pipeline's resize — confirm.
transformed_size = 32

## Network parameters
embedding_dim = 4
# Hidden size is 1 (previously 720) because the model is currently trained on only two animations.
text_encoder_hidden_size = 1
# NOTE(review): Generator below passes a literal 1 for the layer count instead of this constant.
text_encoder_hidden_layers_count = 1
hidden_channels = 16

## Sequences parameters
animation_limit = 5

## Training parameters
learning_rate = 0.01
lambda_recon = 10  # previously 50; weight of the L1 reconstruction term
lambda_seq = 1
lambda_frames_count = 0
validation_period = 25  # validate, plot and checkpoint every N epochs

In [None]:
# Load data
data_set = LabeledAnimationLoader(preprocessed_data_path=dataset_root)

# Text tokenizer
tokenizer = get_tokenizer('basic_english') # private

all_descriptions = data_set.get_descriptions() # private

print(all_descriptions)

# Vocabulary-derived values: the padding token, an encoding function and the vocabulary size.
padding_word, encode, count_of_words = get_text_parameters(tokenizer, all_descriptions)

# Preview the encoding of the first few descriptions as a sanity check. The bound
# prevents an IndexError when the dataset holds fewer than four descriptions.
for i in range(min(4, len(all_descriptions))):
    print(encode(all_descriptions[i]))

In [None]:
# Build the training and validation loaders; split sizes come from the config cell.
train_loader, val_loader = create_loaders(
    data_set,
    padding_word,
    batch_size,
    train_size,
    validation_size,
    ignore_size,
)

In [None]:
# Visual sanity check: plot every training sample one by one, then each validation batch.
# NOTE(review): assumes every training batch holds exactly `batch_size` samples; a smaller
# final batch would make `plot_sample` index past the end — confirm the loader's behaviour.
for train_batch in train_loader:
    for sample_index in range(batch_size):
        print(sample_index)
        plot_sample(train_batch, sample_index)

for val_batch in val_loader:
    plot_samples(val_batch, validation_size)

In [None]:
# Setup GPU training if available: prefer CUDA, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

In [None]:
# Images Decoder block
from modules.ImagesDecoders import InterpolationDecoder

# Encoder blocks
from modules.Encoders import TextEncoder
from modules.ImageNetworkBlocks import ContractingBlock, FeatureMapBlock, FeatureExchangeBlock
from modules.ImagesDecoders import ContractingPath, ExpandingPath
from modules.Discriminators import ThreeImagesAndDescriptionDiscriminator as Discriminator

In [None]:
# Generator
class Generator(nn.Module):
    """Text-conditioned generator that predicts a start frame and an end frame from a reference image.

    The description is encoded into a hidden vector ``h``. The reference image goes through
    a shared contracting path (conditioned on ``h``), then through two independent
    branches, each made of a feature-exchange block and an expanding path: one branch
    for the start frame and one for the end frame.

    Args:
        count_of_words: vocabulary size, passed to ``TextEncoder``.
        embedding_dim: word-embedding size, passed to ``TextEncoder``.
        padding_word: padding token, passed to ``TextEncoder``.
        h_size: text hidden-state size; also passed to the image blocks.
        image_channels: number of input/output image channels.
        hidden_channels: base channel count of the contracting/expanding paths.
        animation_limit: not used by this class; kept for interface compatibility.
    """

    def __init__(self, count_of_words, embedding_dim, padding_word, h_size, image_channels, hidden_channels, animation_limit):
        super(Generator, self).__init__()

        # Single-layer text encoder (the last argument is the layer count).
        self.text_encoder = TextEncoder(
            count_of_words,
            embedding_dim,
            padding_word,
            h_size,
            1)

        self.contractingPath = ContractingPath(image_channels, image_channels, hidden_channels, h_size)

        # NOTE(review): 2048 == 128 * 4 * 4; presumably the flattened size of the contracting
        # path's deepest feature map for 32x32 inputs with hidden_channels=16 — confirm before
        # changing image size or channel counts.
        self.featureExchangeStart = FeatureExchangeBlock(2048, (128, 4, 4), h_size)
        self.featureExchangeEnd = FeatureExchangeBlock(2048, (128, 4, 4), h_size)

        self.expandingPathStart = ExpandingPath(image_channels, image_channels, hidden_channels, h_size)
        self.expandingPathEnd = ExpandingPath(image_channels, image_channels, hidden_channels, h_size)

    def forward(self, description, reference):
        """Generate ``(start, end)`` frames for ``reference`` conditioned on ``description``.

        Args:
            description: encoded description batch fed to ``TextEncoder``.
            reference: reference image batch fed to ``ContractingPath``.

        Returns:
            A tuple ``(start, end)`` with the outputs of the start and end expanding paths.
        """
        h = self.text_encoder(description)

        # Per the encoder's layout, h is [layers, batch, h_size] with a single layer; keep layer 0.
        h = h[0]

        # Shared contraction, then a separate feature exchange + expansion per output frame.
        x0, x1, x2, x3 = self.contractingPath(reference, h)

        xfcStart, _ = self.featureExchangeStart(x3, h)
        xfcEnd, _ = self.featureExchangeEnd(x3, h)

        start = self.expandingPathStart(x0, x1, x2, xfcStart)
        # Bug fix: the end frame is decoded by its own expanding path. It previously reused
        # expandingPathStart, which left expandingPathEnd unused and untrained.
        end = self.expandingPathEnd(x0, x1, x2, xfcEnd)

        return start, end

    def device(self):
        """Return the device of the module's first parameter."""
        return next(self.parameters()).device

In [None]:
# Loss function

def get_gen_loss(
    gen,
    disc,
    start_real,
    end_real,
    description,
    reference,
    adv_criterion,
    recon_criterion,
    lambda_recon,
    discription_original=None):
    """Generator loss: adversarial term plus weighted L-style reconstruction of both frames.

    The generator is rewarded when ``disc`` scores its fake frames as real (target 1).
    The reconstruction term compares each generated frame with its ground truth and is
    scaled by ``lambda_recon``.

    Args:
        gen: called as ``gen(description, reference)`` -> ``(start_fake, end_fake)``.
        disc: called as ``disc(description, reference, start, end)`` -> scores.
        start_real, end_real: ground-truth start/end frames.
        description, reference: generator/discriminator conditioning inputs.
        adv_criterion: adversarial loss, called as ``adv_criterion(scores, targets)``.
        recon_criterion: reconstruction loss, called as ``recon_criterion(fake, real)``.
        lambda_recon: weight of the reconstruction term.
        discription_original: unused; accepted for call-site compatibility.

    Returns:
        The scalar loss ``adv + lambda_recon * (recon(start) + recon(end))``.
    """
    start_fake, end_fake = gen(description, reference)

    scores = disc(description, reference, start_fake, end_fake)
    adversarial = adv_criterion(scores, torch.ones_like(scores))

    reconstruction = recon_criterion(start_fake, start_real) + recon_criterion(end_fake, end_real)

    return adversarial + lambda_recon * reconstruction

In [None]:
class TrainingData:
    """Move one loader batch to the training device and split its frame sequence.

    Attributes (names kept as-is because other cells use them):
        discription: ``samples['desc_encoded']`` moved to the device.
        discription_original: ``samples['description']``, left untouched.
        reference, start_real, end_real: entries 0, 1 and 2 of ``samples['images']``
            along dim 0, moved to the device.
    """

    def __init__(self, samples, target_device=None):
        """Args:
            samples: a batch dict with 'desc_encoded', 'description' and 'images' keys.
            target_device: device to move tensors to; defaults to the notebook-level ``device``.
        """
        if target_device is None:
            target_device = device

        self.discription = samples['desc_encoded'].to(target_device)
        self.discription_original = samples['description']
        images = samples['images'].to(target_device)

        # Integer indexing removes exactly dim 0 (the frame position). The previous
        # index_select(...).squeeze() removed *every* size-1 dimension, which also dropped
        # the batch dimension whenever a batch contained a single sample.
        self.reference = images[0]
        self.start_real = images[1]
        self.end_real = images[2]

In [None]:
# Display data functions

def plot_generated_samples(generator, description_encoded, description, reference, start_frame, end_frame, demos_to_show):
    """Run the generator on a batch and plot its frames next to the real ones.

    Args:
        generator: model called as ``generator(description_encoded, reference)``,
            returning ``(start_fake, end_fake)``.
        description_encoded: encoded descriptions fed to the generator.
        description: raw descriptions, forwarded to ``plot_edge_frames``.
        reference, start_frame, end_frame: real reference/start/end frames.
        demos_to_show: number of samples, forwarded to ``plot_edge_frames``.
    """
    # Inference only: no_grad avoids building an autograd graph just to draw figures,
    # and the outputs come back without requires_grad (so no detach is needed).
    with torch.no_grad():
        start_fake, end_fake = generator(description_encoded, reference)

    plot_edge_frames(reference, start_frame, end_frame, start_fake, end_fake, description, demos_to_show)

In [None]:
# Network initializing

# Both networks are built from the same text-encoder hyper-parameters.
text_encoder_args = (count_of_words, embedding_dim, padding_word, text_encoder_hidden_size)

generator = Generator(*text_encoder_args, image_channels, hidden_channels, animation_limit).to(device)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate)

# The discriminator's learning rate is 100x smaller than the generator's.
discriminator = Discriminator(*text_encoder_args, image_channels, hidden_channels).to(device)
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate * 0.01)

# BCEWithLogitsLoss applies the sigmoid itself, so discriminator outputs are treated as logits.
adv_criterion = nn.BCEWithLogitsLoss()
# L1 reconstruction loss between generated and real frames.
recon_criterion = nn.L1Loss()

In [None]:

# parameters_load_root = parameters_save_root#'D:\AI\Parameters\SameImages_3'
# discriminator.load_state_dict   (torch.load(os.path.join(parameters_load_root, "discriminator"), map_location=torch.device('cpu')))
# generator.load_state_dict(torch.load(os.path.join(parameters_load_root, "generator"), map_location=torch.device('cpu')))


In [None]:
def TrainDiscriminator(samples, loss_threshold=0.5):
    """Run one discriminator optimisation step on a batch.

    Generated frames are scored against target 0 and real frames against target 1;
    the two adversarial losses are averaged.

    Args:
        samples: a batch from the data loader (see ``TrainingData``).
        loss_threshold: the optimizer step is only taken while the averaged loss is
            above this value (default 0.5, the previously hard-coded value).

    Returns:
        ``(t, disc_loss)``: the ``TrainingData`` for the batch and the averaged loss tensor.
    """
    t = TrainingData(samples)

    # Fake frames are produced without gradients, so only the discriminator is trained here.
    with torch.no_grad():
        start_fake, end_fake = generator(t.discription, t.reference)

    optimizer_discriminator.zero_grad()

    disc_fake_hat = discriminator(t.discription, t.reference, start_fake, end_fake)
    disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
    disc_real_hat = discriminator(t.discription, t.reference, t.start_real, t.end_real)
    disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2

    # Skip the update once the loss is at or below the threshold — presumably to keep
    # the discriminator from overpowering the generator.
    if disc_loss.item() > loss_threshold:
        disc_loss.backward()
        optimizer_discriminator.step()

    return t, disc_loss

def TrainGenerator(samples):
    """Run one generator optimisation step on a batch.

    Uses the notebook-level generator, discriminator, criteria and ``lambda_recon``.

    Args:
        samples: a batch from the data loader (see ``TrainingData``).

    Returns:
        ``(batch, loss)``: the ``TrainingData`` for the batch and the generator loss tensor.
    """
    batch = TrainingData(samples)

    optimizer_generator.zero_grad()

    loss = get_gen_loss(
        generator,
        discriminator,
        batch.start_real,
        batch.end_real,
        batch.discription,
        batch.reference,
        adv_criterion,
        recon_criterion,
        lambda_recon,
    )

    # Backpropagate and step only the generator's optimizer. Gradients that also land on
    # the discriminator's parameters are cleared by its own zero_grad() in TrainDiscriminator.
    loss.backward()
    optimizer_generator.step()

    return batch, loss

def Validate(validation_samples):
    """Compute the generator loss on a validation batch without tracking gradients.

    Args:
        validation_samples: a batch from the validation loader (see ``TrainingData``).

    Returns:
        ``(val_t, validation_gen_loss)``: the ``TrainingData`` for the batch and the loss tensor.
    """
    val_t = TrainingData(validation_samples)

    # Evaluation needs no gradients; no_grad avoids building and keeping an autograd
    # graph through both networks for the whole validation batch.
    with torch.no_grad():
        validation_gen_loss = get_gen_loss(
            generator,
            discriminator,
            val_t.start_real,
            val_t.end_real,
            val_t.discription,
            val_t.reference,
            adv_criterion,
            recon_criterion,
            lambda_recon
        )

    return val_t, validation_gen_loss

In [None]:
# Training
epoch = 0
torch.set_printoptions(precision=2, linewidth=200)

disc_image_training_iterations = 5 #20
gen_training_iterations = 1
# Stop after this many epochs; None keeps training until the cell is interrupted.
max_epochs = None

disc_loss = 1

# torch.save does not create missing directories, so make sure the checkpoint folder exists.
os.makedirs(parameters_save_root, exist_ok=True)

while max_epochs is None or epoch < max_epochs:

    # Training the discriminator: several full passes over the training set per epoch.
    for _ in range(disc_image_training_iterations):
        for samples in train_loader:
            t, disc_loss = TrainDiscriminator(samples)

    # Training the generator
    for _ in range(gen_training_iterations):
        for samples in train_loader:
            t, gen_loss = TrainGenerator(samples)

    # Losses reported here are those of the last batch of each phase.
    print(f"[GEN] Epoch: {epoch} Loss D_img.: {disc_loss} Loss G.: {gen_loss}")#, end = "\r")

    if epoch % validation_period == 0:

        print(f"Train losses. [GEN] Epoch: {epoch}")
        print(f"Loss D_img.: {disc_loss:.2f}")
        print(f"Loss G.: {gen_loss:.2f}.")

        # Checkpoint both networks (overwrites the previous checkpoint).
        torch.save(discriminator.state_dict(), os.path.join(parameters_save_root, "discriminator"))
        torch.save(generator.state_dict(), os.path.join(parameters_save_root, "generator"))

        validation_samples = next(iter(val_loader))
        val_t, validation_gen_loss = Validate(validation_samples)

        print()
        print(f"[GEN] Epoch: {epoch} Val loss G.: {validation_gen_loss}")

        # Display training frames
        print("=================================")
        print("Test samples")

        # Only the last training batch seen by TrainGenerator is plotted.
        plot_generated_samples(generator, t.discription, t.discription_original, t.reference, t.start_real, t.end_real, batch_size)

        # Display validation frames
        print("=================================")
        print("Validation samples")
        plot_generated_samples(generator, val_t.discription, val_t.discription_original, val_t.reference, val_t.start_real, val_t.end_real, validation_size)

    epoch = epoch + 1

In [None]:
# torch.save(image_discriminator.state_dict(), os.path.join(parameters_save_root, "image_discriminator"))
# torch.save(sequence_discriminator.state_dict(), os.path.join(parameters_save_root, "sequence_discriminator"))
# torch.save(generator.state_dict(), os.path.join(parameters_save_root, "generator"))