In [None]:
!rm -r /workspace/content/frames_4/

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip -q drive/MyDrive/Colab\ Notebooks/GameGPT/frames_3.zip -d "/content/"

In [None]:
import string
import numpy as np
import json
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import math
import os
import gc
import random

In [None]:
# Report the available accelerator hardware, then choose the device used for the model and data.
cuda_ok = torch.cuda.is_available()
print("CUDA Available:", cuda_ok)
print("Number of GPUs:", torch.cuda.device_count())
if cuda_ok:
    print("GPU Name:", torch.cuda.get_device_name(0))

device = torch.device('cuda') if cuda_ok else torch.device('cpu')
device

In [None]:
# Storage roots. With Google Drive mounted on Colab, use
#   parent_dir = "drive/MyDrive/Colab Notebooks/GameGPT/"
# The empty string keeps checkpoints relative to the current working directory.
parent_dir = ""
# Folder holding the extracted frames and the moves JSON (relative path).
# NOTE(review): the unzip cell extracts into the absolute "/content/"; this relative
# "content/" only matches it when the working directory is "/" — confirm for your runtime.
content_dir = f"content/{parent_dir}"


In [None]:
datasetnames = ('content/frames_7', 'moves_7')

# Read the move recorded for each frame

In [None]:
moves_file_path = f'{content_dir}{datasetnames[1]}.json'

# Load the per-frame moves. The dataset below looks them up as moves[str(frame_index)].
with open(moves_file_path, 'r') as json_file:
    moves = json.load(json_file)

# Label -> integer id. 'Q' is always 0. The other labels are sorted so the ids are
# the same on every run; iterating a set of strings is not stable across processes
# because str hashing is randomized.
unique_values = sorted(set(moves.values()) - {'Q'})
value_to_int_mapping = {'Q': 0, **{value: i + 1 for i, value in enumerate(unique_values)}}

# Force the fixed 4-label vocabulary (num_input_tokens = 4 in the hyperparameter cell).
override = True
if override:
    value_to_int_mapping = {'Q': 0, 'L': 1, 'N': 2, 'R': 3}

# Derive the reverse mapping from the forward one so the two can never disagree.
int_to_value_mapping = {v: k for k, v in value_to_int_mapping.items()}

# Same dict, with move labels replaced by their integer ids.
new_moves = {key: value_to_int_mapping[value] for key, value in moves.items()}

In [None]:
value_to_int_mapping

In [None]:
int_to_value_mapping

# Create Train Dataset

*  frames: a window of {seq_len} consecutive frames
*  moves: the input key recorded on each of those frames (one per frame)
*  target_frames: the same window shifted one frame ahead, i.e. one target (the next frame) for every input frame

In [None]:
class CustomDataset(Dataset):
    """Sliding-window dataset of consecutive game frames, their moves, and the next frames.

    Frames are read from ``{data_folder}/frame_{i}.png``. The move for frame ``i`` is
    ``moves_dict[str(i)]``. Item ``idx`` is a dict with:

    * ``frames``: ``seq_len`` transformed frames ``idx .. idx + seq_len - 1``, stacked.
    * ``moves``: the ``seq_len`` moves for those frames, as a tensor.
    * ``target_frames``: frames ``idx + 1 .. idx + seq_len``, i.e. the input window
      shifted one step ahead. They are transformed with ``transform_target``, falling
      back to ``transform`` when no target transform is given.

    Indices whose window would run past the data reuse the last complete window.
    """

    def __init__(self, data_folder, moves_dict, seq_len, transform=None, transform_target=None):
        self.data_folder = data_folder
        self.moves_dict = moves_dict
        self.transform = transform
        self.transform_target = transform_target
        self.seq_len = seq_len

        # Only the number of PNG files is used; frames are addressed as frame_{i}.png.
        self.image_filenames = [filename for filename in os.listdir(data_folder) if filename.endswith('.png')]

    def __len__(self):
        # Same bound as before: the final frame is never read, and indices near the end
        # fall back to the last complete window (see __getitem__). Never negative.
        return max(0, len(self.image_filenames) - 2)

    def _load_frame(self, index):
        """Read frame_{index}.png fully into memory and close the file handle."""
        with Image.open(os.path.join(self.data_folder, f'frame_{index}.png')) as img:
            return img.copy()

    def __getitem__(self, idx):
        num_images = len(self.image_filenames)
        # Input frames must satisfy i < N - 2 and target frames i + 1 < N - 1, so the last
        # complete window starts at N - 2 - seq_len. Clamping to it gives the same result as
        # the previous recursive "try idx - 1" fallback, without unbounded recursion.
        start = min(idx, num_images - 2 - self.seq_len)
        if start < 0:
            raise IndexError(
                f"no complete window of {self.seq_len} frames at index {idx} "
                f"({num_images} frames found)")
        frame_ids = range(start, start + self.seq_len)

        frames = [self._load_frame(i) for i in frame_ids]
        moves = [self.moves_dict[f"{i}"] for i in frame_ids]
        # The target for input frame i is frame i + 1.
        target_frames = [self._load_frame(i + 1) for i in frame_ids]

        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]
        # Bug fix: this branch used to be guarded by `self.transform` but applied
        # `self.transform_target`, which crashed when only `transform` was given.
        target_transform = self.transform_target if self.transform_target is not None else self.transform
        if target_transform is not None:
            target_frames = [target_transform(frame) for frame in target_frames]

        return {
            'frames': torch.stack(frames),
            'moves': torch.tensor(moves),
            'target_frames': torch.stack(target_frames),
        }

# Define a transform for preprocessing images
data_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    #AddInverseBellShapeNoise(max_noise=0.7, sigma=0.3),
    transforms.ToTensor(),
])
data_transform_target = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Model Definition
Basic Image To Vector Encoder

In [None]:
class ImageToVector(nn.Module):
    """Encode an image into a ``d_model`` vector plus U-Net skip features.

    Four double-conv encoder stages, each followed by 2x2 max-pooling, and a bottleneck
    stage shrink the input by 16x per side. The flattened bottleneck is projected to
    ``d_model`` by a linear layer followed by dropout.

    Args:
        input_channels: channels of the input image.
        image_size: input height/width. The linear layer's input size is computed for a
            square input of this size; presumably a multiple of 16 (TODO confirm).
        conv_size: channel width of the first stage; it doubles at each following stage.
        d_model: size of the output vector.
        bottleneck_channels: not used by this module; kept for signature compatibility.
        seq_len: stored on the module; not used by ``forward``.
        dropout: dropout probability applied to the output vector.
    """

    def __init__(self, input_channels, image_size, conv_size, d_model, bottleneck_channels, seq_len, dropout):
        super(ImageToVector, self).__init__()
        self.seq_len = seq_len
        # Bug fix: stored so _get_flattened_size uses this instance's width instead of
        # silently reading a notebook-global `conv_size`.
        self.conv_size = conv_size

        # Contracting path: channel width doubles at every stage.
        self.encoder1 = self.conv_block(input_channels, conv_size)
        self.encoder2 = self.conv_block(conv_size, conv_size * 2)
        self.encoder3 = self.conv_block(conv_size * 2, conv_size * 4)
        self.encoder4 = self.conv_block(conv_size * 4, conv_size * 8)

        # Bottleneck
        self.bottleneck = self.conv_block(conv_size * 8, conv_size * 16)

        self.flatten = nn.Flatten()

        # Linear projection of the flattened bottleneck to d_model
        self.fc = nn.Linear(self._get_flattened_size(image_size), d_model)
        self.dropout = nn.Dropout(dropout)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; padding keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _get_flattened_size(self, image_size):
        """Number of features in the flattened bottleneck for a square ``image_size`` input."""
        size = image_size // 2**4  # 4 max-pool layers, each halving height and width
        return size * size * (self.conv_size * 16)  # height * width * channels

    def reparameterize(self, mu, log_var):
        """Reparameterization trick: return ``mu + eps * exp(0.5 * log_var)`` with eps ~ N(0, 1).

        Not called by ``forward``.
        """
        std = torch.exp(0.5 * log_var)  # Standard deviation from log-variance
        epsilon = torch.randn_like(std)  # Random noise
        return mu + epsilon * std

    def forward(self, x):
        """Encode a batch of images.

        Returns:
            ``(vector, bottleneck, [enc1, enc2, enc3, enc4])``: the ``d_model`` projection,
            the bottleneck feature map, and the per-stage features the decoder uses as
            skip connections.
        """
        # Encoder path
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2))
        enc3 = self.encoder3(F.max_pool2d(enc2, 2))
        enc4 = self.encoder4(F.max_pool2d(enc3, 2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))

        # Flatten and project to d_model
        x = self.flatten(bottleneck)
        x = self.fc(x)
        x = self.dropout(x)

        return x, bottleneck, [enc1, enc2, enc3, enc4]  # residuals for the decoder

Input Vector Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an input of shape (B, T, d_model).

    Even feature columns hold ``sin(pos * w)`` and odd columns hold ``cos(pos * w)``,
    with ``w = 10000 ** (-2i / d_model)``. The table is registered as a non-persistent
    buffer: it follows ``module.to(...)`` like the parameters do, and it is left out of
    ``state_dict``, so checkpoints saved without it still load with ``strict=True``.
    """

    def __init__(self, d_model, max_seq_len):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        encoding = torch.zeros(max_seq_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Bug fix: previously a plain attribute created on the global `device`, so it
        # did not move with the module.
        self.register_buffer('encoding', encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)]

Self Attention Layer

In [None]:
class Attention(nn.Module):
    """Causal multi-head self-attention with a post-norm residual connection.

    Input and output are (B, T, d_model). Position t attends only to positions <= t,
    using a lower-triangular mask of size ``seq_len``.

    NOTE(review): scores are scaled by d_model ** -0.5 rather than the conventional
    head_dim ** -0.5 ((d_model // n_head) ** -0.5). Left unchanged because changing it
    alters the outputs of weights already trained with this scaling.
    """

    def __init__(self, d_model, n_head, seq_len, dropout):
        super(Attention, self).__init__()

        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.c_attn = nn.Linear(d_model, 3 * d_model)  # fused Q/K/V projection
        self.linear = nn.Linear(d_model, d_model)  # output projection
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

        self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))
        self.d_model = d_model
        self.n_head = n_head

    def forward(self, x):
        batch, steps, width = x.shape
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, steps, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (to_heads(part) for part in self.c_attn(x).split(self.d_model, dim=2))

        scores = (query @ key.transpose(-2, -1)) * width ** -0.5
        future = self.tril[:steps, :steps] == 0
        scores = scores.masked_fill(future, float('-inf'))
        context = F.softmax(scores, dim=-1) @ value

        merged = context.transpose(1, 2).contiguous().view(batch, steps, width)
        # Project, apply dropout, add the residual, then normalize.
        return self.norm(self.dropout(self.linear(merged)) + x)

In [None]:
class StackedAttention(nn.Module):
    """Apply ``num_layers`` independent causal self-attention layers one after another."""

    def __init__(self, d_model, n_head, seq_len, num_layers, dropout):
        super(StackedAttention, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(Attention(d_model, n_head, seq_len, dropout))
        self.attention_layers = nn.ModuleList(layers)

    def forward(self, x):
        hidden = x
        for layer in self.attention_layers:
            hidden = layer(hidden)
        return hidden

In [None]:
class CrossAttention(nn.Module):
    """Multi-head attention from ``queries`` onto ``keys``/``values`` with a causal mask.

    Query step t may attend to key steps <= t, using a lower-triangular mask of shape
    (seq_len_q, seq_len_k). The attended result goes through an output projection and
    dropout, is added to ``queries`` (residual), and is layer-normalized. Inputs are
    (B, T, d_model).

    NOTE(review): scores are scaled by d_model ** -0.5 instead of head_dim ** -0.5.
    Left unchanged so trained weights keep their behavior.
    """

    def __init__(self, d_model, n_head, seq_len_q, seq_len_k, dropout):
        super(CrossAttention, self).__init__()

        assert d_model % n_head == 0, "d_model must be divisible by n_head"

        self.c_attn_q = nn.Linear(d_model, d_model)
        self.c_attn_k = nn.Linear(d_model, d_model)
        self.c_attn_v = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

        # Causal mask: query step t sees key steps 0..t
        self.register_buffer('mask', torch.tril(torch.ones(seq_len_q, seq_len_k)))
        self.d_model = d_model
        self.n_head = n_head

    def forward(self, queries, keys, values):
        batch, q_len, width = queries.shape
        _, k_len, _ = keys.shape
        head_dim = width // self.n_head

        def split_heads(t, length):
            # (B, L, C) -> (B, n_head, L, head_dim)
            return t.view(batch, length, self.n_head, head_dim).transpose(1, 2)

        q = split_heads(self.c_attn_q(queries), q_len)
        k = split_heads(self.c_attn_k(keys), k_len)
        v = split_heads(self.c_attn_v(values), k_len)

        scores = (q @ k.transpose(-2, -1)) * width ** -0.5  # (B, n_head, T_q, T_k)
        blocked = self.mask[:q_len, :k_len] == 0
        weights = F.softmax(scores.masked_fill(blocked, float('-inf')), dim=-1)

        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.norm(self.dropout(self.out_linear(merged)) + queries)

Basic Vector To Image decoder

In [None]:
class VectorToImage(nn.Module):
    """Decode a ``d_model`` vector back into an image using U-Net skip connections.

    A linear layer reshapes the vector into a (conv_size*16, image_size//16,
    image_size//16) map. That map and the encoder bottleneck are each upsampled 2x and
    concatenated with the deepest skip feature. Three further upsample -> concat skip ->
    double-conv stages restore the full resolution. A 1x1 conv maps to
    ``output_channels``; dropout and then tanh are applied to the result.
    """

    def __init__(self, d_model, output_channels, image_size, conv_size, dropout):
        super(VectorToImage, self).__init__()
        self.output_channels = output_channels
        self.image_size = image_size
        self.conv_size = conv_size

        side = image_size // 16
        widest = conv_size * 16
        # Size of the feature map produced by the linear layer
        self.fc_out_size = widest * side * side
        self.fc = nn.Linear(d_model, self.fc_out_size)

        # Deepest stage: two parallel upsamplers (projected vector, bottleneck) whose outputs
        # are concatenated with the deepest skip feature -> 3 * conv_size*8 channels.
        self.upconv4a = self.upconv(widest, conv_size * 8)
        self.upconv4b = self.upconv(widest, conv_size * 8)
        self.decoder4 = self.conv_block(conv_size * 24, conv_size * 8)

        self.upconv3 = self.upconv(conv_size * 8, conv_size * 4)
        self.decoder3 = self.conv_block(conv_size * 8, conv_size * 4)

        self.upconv2 = self.upconv(conv_size * 4, conv_size * 2)
        self.decoder2 = self.conv_block(conv_size * 4, conv_size * 2)

        self.upconv1 = self.upconv(conv_size * 2, conv_size)
        self.decoder1 = self.conv_block(conv_size * 2, conv_size)

        self.out_conv = nn.Conv2d(conv_size, output_channels, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; padding keeps the spatial size."""
        layers = []
        for c_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def upconv(self, in_channels, out_channels):
        """2x spatial upsampling with a stride-2 transposed conv followed by ReLU."""
        up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        return nn.Sequential(up, nn.ReLU(inplace=True))

    def forward(self, x, bottleneck, residuals):
        # x: (N, d_model) with N = batch * timesteps; residuals = [enc1, enc2, enc3, enc4]
        enc1, enc2, enc3, enc4 = residuals[0], residuals[1], residuals[2], residuals[3]

        side = self.image_size // 16
        seed = self.fc(x).view(-1, self.conv_size * 16, side, side)

        feat = torch.cat([self.upconv4a(seed), self.upconv4b(bottleneck), enc4], dim=1)
        feat = self.decoder4(feat)

        stages = (
            (self.upconv3, self.decoder3, enc3),
            (self.upconv2, self.decoder2, enc2),
            (self.upconv1, self.decoder1, enc1),
        )
        for up, dec, skip in stages:
            feat = dec(torch.cat([up(feat), skip], dim=1))

        return torch.tanh(self.dropout(self.out_conv(feat)))


In [None]:
class AddNoise(nn.Module):
    """Add sparse, reverse-cumulative uniform noise to sequences of vectors.

    For every (sequence, step, feature) element an increment is drawn from U(-1, 1) and
    kept with probability 1/T, where T = ``seq_len``. The noise at step t is the sum of
    the kept increments from t through the last step, so earlier steps collect at least
    as many increments as later ones. The noise is clipped to [-1, 1] before being added.

    Accepts flattened (B*T, d_model) input, which is returned in that shape as before, or
    (B, T, d_model) input, which is returned as (B, T, d_model). The input is not modified.
    """

    def __init__(self, seq_len):
        super(AddNoise, self).__init__()
        self.seq_len = seq_len

    def forward(self, x):
        T = self.seq_len
        d_model = x.size(-1)
        # View as (B, T, d_model); raises if the leading size is not a multiple of T.
        seq = x.reshape(-1, T, d_model)
        B = seq.size(0)

        # Sparse increments: U(-1, 1), each kept when a U(0, 1) draw exceeds 1 - 1/T.
        delta_noise = torch.empty(B, T, d_model, device=x.device).uniform_(-1, 1)
        keep = torch.empty(B, T, d_model, device=x.device).uniform_(0.0, 1.0) > (1 - 1 / T)
        delta_noise = delta_noise * keep.float()

        # Reverse cumulative sum over time: noise[t] = delta[t] + ... + delta[T-1].
        cumulative_noise = delta_noise.flip(1).cumsum(dim=1).flip(1)
        # Bug fix: `clip` is not in-place, and its result used to be discarded.
        cumulative_noise = cumulative_noise.clamp(-1, 1)

        # Bug fix: out-of-place add, so the caller's tensor is no longer mutated.
        noised = seq + cumulative_noise
        if x.dim() == 2:
            return noised.view(-1, d_model)
        return noised.view(x.shape)

Final model

In [None]:
class GameModel(nn.Module):
    """Next-frame model: U-Net encoder -> attention over time -> U-Net decoder.

    ``forward(k, x)`` takes move ids ``k`` and frames ``x`` of shape (B, T, C, H, W).
    It dispatches on ``forward_type``:

    * ``"autoencoder"``: reconstruct each frame through the encoder/decoder only.
    * ``"diffusion"``: latent denoising pass (see ``forward_diffusion``).
    * anything else: ``forward_full``, which returns frames of shape (B, T, C, H, W).

    The move embedding and the cross-attention are always frozen. In ``"autoencoder"``
    mode the self-attention stack is frozen too.
    """

    def __init__(self, forward_type, d_model, image_size, channels, num_input_tokens, seq_len, n_head, num_a_layers, ca_n_head, conv_size, bottleneck_channels, dropout):
        super(GameModel, self).__init__()

        self.forward_type = forward_type
        self.d_model = d_model
        self.seq_len = seq_len
        self.n_head = n_head

        # Embedding layer for input tokens (move ids)
        self.input_token_embedding = nn.Embedding(num_input_tokens, d_model)

        # Per-frame encoder: image -> d_model vector (+ bottleneck and skip features)
        self.image_to_vector = ImageToVector(input_channels=channels, image_size=image_size, conv_size=conv_size, d_model=d_model, bottleneck_channels=bottleneck_channels, seq_len=seq_len, dropout=dropout)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len=seq_len)

        # Cross-attention from frame vectors onto move embeddings
        self.cross_attention = CrossAttention(d_model, ca_n_head, seq_len_q=seq_len, seq_len_k=seq_len, dropout=dropout)

        # Stacked causal self-attention layers
        self.stacked_attention = StackedAttention(d_model, n_head, seq_len, num_a_layers, dropout)

        self.add_noise = AddNoise(seq_len)

        # Decoder: vector (+ bottleneck and skip features) -> image
        self.vector_to_image = VectorToImage(d_model=d_model, output_channels=channels, image_size=image_size, conv_size=conv_size, dropout=dropout)

        # The freeze loops used to print inside `if param.requires_grad` right after setting
        # it to False, so the prints could never run; they are folded into _freeze.
        self._freeze(self.input_token_embedding)
        self._freeze(self.cross_attention)
        if forward_type == "autoencoder":
            self._freeze(self.stacked_attention)

    @staticmethod
    def _freeze(module):
        """Exclude every parameter of ``module`` from gradient updates."""
        for param in module.parameters():
            param.requires_grad = False

    def forward(self, k, x):
        if self.forward_type == "autoencoder":
            return self.forward_autoencoder(k, x)
        elif self.forward_type == "diffusion":
            return self.forward_diffusion(k, x)
        return self.forward_full(k, x)

    def forward_autoencoder(self, k, x):
        """Encode and decode each frame independently; ``k`` is ignored."""
        B, T, _, _, _ = x.shape

        image_vectors, bottleneck, residuals = self.image_to_vector(x.view(-1, *x.shape[2:]))  # Flatten batch and timesteps

        # Bug fix: the decoder takes (x, bottleneck, residuals); bottleneck was missing,
        # so this path raised a TypeError.
        generated_images = self.vector_to_image(image_vectors.view(-1, self.d_model), bottleneck, residuals)
        return generated_images.view(B, T, *generated_images.shape[1:])  # (B, T, C, H, W)

    def forward_diffusion(self, k, x):
        """Latent denoising pass.

        Returns ``(y_noised_pred, truth_noise, clean_vectors)``, each of shape
        (B, T, d_model): the attention stack's prediction, its target, and the un-noised
        encoder vectors.

        NOTE(review): the previous target expression used an undefined name
        (``truth_noise``) and mis-shaped slices, so this method could never run. The target
        here is a best guess at the intent: each step's successor in the noised sequence
        (taken before cross-attention), with the clean last vector appended. Confirm
        before training in this mode.
        """
        B, T, _, _, _ = x.shape

        image_vectors, bottleneck, residuals = self.image_to_vector(x.view(-1, *x.shape[2:]))  # Flatten batch and timesteps
        clean_vectors = image_vectors.view(B, T, -1)  # (B, T, d_model)

        # AddNoise works on flattened (B*T, d_model) vectors. Clone so the clean vectors
        # returned below are never modified by the noise step.
        y_noised = self.add_noise(image_vectors.clone()).view(B, T, -1)

        k_embeddings = self.input_token_embedding(k)  # (B, T, d_model)
        y_cond = self.cross_attention(queries=y_noised, keys=k_embeddings, values=k_embeddings)
        y_pos = self.positional_encoding(y_cond)
        y_noised_pred = self.stacked_attention(y_pos)

        truth_noise = torch.cat((y_noised[:, 1:], clean_vectors[:, -1:]), dim=1)

        return y_noised_pred, truth_noise, clean_vectors

    def forward_full(self, k, x):
        """Predict one output frame per input frame; the result is added to ``x`` as a residual."""
        B, T, _, _, _ = x.shape

        image_vectors, bottleneck, residuals = self.image_to_vector(x.view(-1, *x.shape[2:]))  # Flatten batch and timesteps
        image_vectors = image_vectors.view(B, T, -1)  # (B, T, d_model)

        k_embeddings = self.input_token_embedding(k)  # (B, T, d_model)

        # NOTE(review): this cross-attention output is never used. `y` is overwritten below
        # by the self-attention stack, which reads `image_vectors`, so the moves `k` do not
        # currently influence the prediction. Kept as-is because weights trained with this
        # code learned without it; confirm whether `y_pos` should be built from `y`.
        y = self.cross_attention(queries=image_vectors, keys=k_embeddings, values=k_embeddings)

        y_pos = self.positional_encoding(image_vectors)
        y = self.stacked_attention(y_pos)

        # Residual connection around the attention stack
        y = image_vectors + y

        generated_images = self.vector_to_image(y.view(-1, self.d_model), bottleneck, residuals)  # Flatten timesteps
        generated_images = generated_images.view(B, T, *generated_images.shape[1:])  # (B, T, C, H, W)

        return generated_images + x

In [None]:
seq_len = 5
d_model = 512
image_size = 128
conv_size = 64
bottleneck_channels = 16
n_head = 8
num_a_layers = 4
ca_n_head = 1
num_input_tokens = 4
dropout = 0.3
channels = 3

In [None]:
torch.cuda.empty_cache()
gc.collect()
model = GameModel(forward_type="full", d_model=d_model, image_size=image_size, channels=channels, num_input_tokens=num_input_tokens, seq_len=seq_len, n_head=n_head, num_a_layers=num_a_layers, ca_n_head=ca_n_head, conv_size=conv_size, bottleneck_channels=bottleneck_channels, dropout=dropout)
model.to(device)

In [None]:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total number of parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
# torch.cuda memory statistics only accept CUDA devices; skip them on a CPU runtime.
if device.type == 'cuda':
    print(f"Memory Allocated: {torch.cuda.memory_allocated(device) / (1024**3):.2f} GB")
    print(f"Memory Cached: {torch.cuda.memory_reserved(device) / (1024**3):.2f} GB")

Preload model weights from a saved checkpoint

In [None]:
model_path = f"{parent_dir}model_ggptunet_v1.pth"

In [None]:
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime (and vice versa).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
class VelocityMSELoss(nn.Module):
    """Per-pixel MSE, weighted per frame by how much the target changed since the previous frame.

    Inputs have shape (batch, seq_len, channels, height, width). Frame t >= 1 is weighted
    by the mean absolute difference between target frames t and t-1; frame 0 gets weight 0.
    The weighted squared errors are averaged over all elements, including frame 0, so the
    result equals ``(mse * weights).mean()`` over the full tensor.
    """

    def __init__(self):
        super(VelocityMSELoss, self).__init__()

    def forward(self, predicted_frames, target_frames):
        # Unpacking doubles as a check that the predictions are 5-D.
        _, _, _, _, _ = predicted_frames.shape

        mse_loss = F.mse_loss(predicted_frames, target_frames, reduction='none')  # same shape as inputs

        # Mean |target[t] - target[t-1]| per (sequence, frame), kept broadcastable:
        # (batch, seq_len-1, 1, 1, 1). This replaces the expanded full-size weight tensor
        # and the unused prediction differences the previous version computed.
        weights = (target_frames[:, 1:] - target_frames[:, :-1]).abs().mean(dim=(2, 3, 4), keepdim=True)

        # Frame 0 has zero weight: it adds nothing to the sum but still counts in the mean.
        return (mse_loss[:, 1:] * weights).sum() / mse_loss.numel()

In [None]:
class KLMSELoss(nn.Module):
    """Weighted sum of a reconstruction MSE and a Gaussian KL-divergence term.

    ``loss = 0.8 * mean((pred - target) ** 2) + 0.2 * KL``, where
    ``KL = -0.5 * sum(1 + log_var - mu ** 2 - exp(log_var))`` is summed (not averaged)
    over every element of ``mu`` / ``log_var``.
    """

    def __init__(self):
        super(KLMSELoss, self).__init__()

    def forward(self, predicted_frames, target_frames, mu, log_var):
        # Unpacking only checks that the predictions are 5-D.
        _batch, _steps, _channels, _height, _width = predicted_frames.shape

        reconstruction = F.mse_loss(predicted_frames, target_frames, reduction='mean')  # scalar
        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
        return 0.8 * reconstruction + 0.2 * kl_div

Train Setup

In [None]:

# Candidate losses; `criterion` selects the one used by the training loop.
velocity_criterion = VelocityMSELoss()  # MSE weighted by per-frame target motion
mae_criterion = nn.L1Loss()  # MAE (L1 Loss)
mse_criterion = nn.MSELoss()  # MSE (L2 Loss)
# MSE + KL term. Its forward also needs (mu, log_var), so it cannot be used as a
# drop-in `criterion(output, target)` in the training loop.
klmse_criterion = KLMSELoss()
criterion = mse_criterion

# Very small rate, presumably because training resumes from the preloaded checkpoint.
LEARNING_RATE = 1e-7
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
custom_dataset = CustomDataset(datasetnames[0], new_moves, seq_len=seq_len, transform=data_transform, transform_target=data_transform_target)
dataloader = DataLoader(custom_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)

In [None]:
torch.cuda.empty_cache()
gc.collect()

# Train

Train Loop

TODO: add validation-loss monitoring

In [None]:
def show_frames(frame_seq):
    """Display each frame of one (T, C, H, W) tensor with values in [0, 1] as a PIL image."""
    for t in range(frame_seq.size(0)):  # iterates over timesteps, not channels
        frame = frame_seq[t]
        print(frame.shape)  # (C, H, W), e.g. [3, 128, 128]
        # Clip before the uint8 cast; out-of-range values would otherwise wrap around.
        pixels = (frame.detach().cpu().clamp(0, 1).numpy() * 255).astype('uint8')
        display(Image.fromarray(np.transpose(pixels, (1, 2, 0))))  # PIL expects (H, W, C)


# Training loop
model.train()
debug = False  # show the input frames after self-feeding
debug2 = False  # show the model output

# Per-epoch self-feeding schedule: with probability 1 - random_repeat[epoch], a batch's
# input frames 1..T-1 are replaced (iters_repeat[epoch] times) by the model's own predictions.
random_repeat = [1, 0]
iters_repeat = [0, 1]
num_epochs = len(random_repeat)
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_idx, x in enumerate(dataloader):
        frames = x['frames'].to(device)
        moves = x['moves'].to(device)
        target = x['target_frames'].to(device)

        if len(frames) > 0:
            if torch.rand(1).item() > random_repeat[epoch]:
                # NOTE: the model stays in train mode here, so dropout is active and
                # BatchNorm running statistics are updated even under no_grad.
                with torch.no_grad():
                    for _ in range(iters_repeat[epoch]):
                        first_frame = frames[:, 0:1]  # (B, 1, C, H, W)
                        predicted_frames = model(moves, frames)
                        # Prediction t is the frame after input t, so predictions 0..T-2 replace inputs 1..T-1.
                        frames = torch.cat((first_frame, predicted_frames[:, :-1]), dim=1)
                        if debug:
                            show_frames(frames[0])

            # Forward pass
            output = model(moves, frames)
            if debug2:
                show_frames(output[0])

            # Loss, backward pass, gradient clipping and optimizer step
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], "
                      f"Batch [{batch_idx + 1}/{len(dataloader)}], "
                      f"Loss: {loss.item():.8f}")

            if (batch_idx + 1) % 100 == 0:
                # Extra metrics are for reporting only; no autograd graph needed.
                with torch.no_grad():
                    vloss = velocity_criterion(output, target)
                    mae = mae_criterion(output, target)
                    mse = mse_criterion(output, target)
                print(f"Epoch [{epoch + 1}/{num_epochs}], "
                      f"Batch [{batch_idx + 1}/{len(dataloader)}], "
                      f"Loss: {loss.item():.8f}, "
                      f"VLoss: {vloss.item():.8f}, "
                      f"MAE: {mae.item():.8f}, "
                      f"MSE: {mse.item():.8f}")

            total_loss += loss.item()

            # Save a checkpoint every 1000 batches
            if (batch_idx + 1) % 1000 == 0:
                torch.save(model.state_dict(), f"{parent_dir}model_checkpoint_epoch{epoch + 1}_batch{batch_idx + 1}.pth")

    # Average loss for the epoch, then an end-of-epoch checkpoint
    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_loss:.8f}")
    torch.save(model.state_dict(), f"{parent_dir}model_epoch_ended_{epoch + 1}.pth")

print("Training complete!")

In [None]:
torch.save(model.state_dict(), f"{parent_dir}model_dgpt_v2.pth")

In [None]:
torch.cuda.empty_cache()
gc.collect()

# Play Game

In [None]:
class PlayDataset(Dataset):
    """Rolling window of frames and moves used to play the model autoregressively.

    Starts from ``frame_0.png`` of the training folder repeated ``seq_len`` times, with
    move id ``start_move`` (2 is 'N' under the fixed mapping in the moves cell).
    ``move()`` slides the window forward by one step.
    """

    def __init__(self, seq_len=5, start_move=2):
        with Image.open(os.path.join(datasetnames[0], "frame_0.png")) as img:
            self.current = img.copy()  # copy so the file handle is closed right away
        self.moves = [start_move] * seq_len
        self.frames = [self.current] * seq_len

    def __len__(self):
        # Nominal length; the window is generated on the fly, not indexed.
        return 1000

    def __getitem__(self, idx=None):
        """Return the current window as a batch of one.

        Returns ``{'frames': (1, T, C, H, W) tensor, 'moves': (1, T) tensor}``. ``idx`` is
        ignored; it is accepted so the class satisfies the ``Dataset`` indexing protocol.
        """
        frames = torch.stack([data_transform_p(frame) for frame in self.frames]).unsqueeze(0)
        moves = torch.tensor(self.moves).unsqueeze(0)
        return {'frames': frames, 'moves': moves}

    def move(self, move, frame):
        """Append ``move`` (a label such as 'L') and the resulting ``frame``; drop the oldest step."""
        self.moves = self.moves[1:] + [value_to_int_mapping[move]]
        self.frames = self.frames[1:] + [frame]

data_transform_p = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

In [None]:
play_dataset = PlayDataset()
model.eval()

In [None]:
import random

# Let the model play by itself: at every step, predict the next frame, then
# feed it back together with a uniformly random move.
NUM_PLAY_STEPS = 100
new_size = (512, 512)  # upscaled size of each frame in the preview GIF
results = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for _ in range(NUM_PLAY_STEPS):
        x = play_dataset.__getitem__()

        moves = x.get('moves').to(device)
        frames = x.get('frames').to(device)

        # Forward pass; output assumed to be (1, seq, C, H, W) -- TODO confirm.
        output = model(moves, frames)

        predicted_frames = output.squeeze(0).permute(0, 2, 3, 1)
        # Clamp before scaling: values outside [0, 1] would wrap around in the uint8 cast.
        predicted_np = (predicted_frames.clamp(0, 1).cpu().numpy() * 255).astype('uint8')
        predicted = predicted_np[4]  # prediction at the last step of the 5-step window
        image_pil = Image.fromarray(predicted)

        movement = random.choice(["L", "N", "R"])
        play_dataset.move(movement, image_pil)
        results.append(image_pil.resize(new_size, Image.NEAREST))


In [None]:
# Write all generated frames as one looping animated GIF.
gif_path = "play.gif"
results[0].save(
    gif_path,
    save_all=True,
    append_images=results[1:],
    duration=0,
    loop=0,
)

In [None]:
from IPython.display import Image as ImageD

# Render the saved animation inline.
gif_preview = ImageD(filename="play.gif")
display(gif_preview)

# Play Manually

In [None]:
import random

# Reset the interactive window and the output buffer for a scripted run.
model.eval()
play_dataset = PlayDataset()
results = []

In [None]:
# Drive the model with a fixed script: 5 steps right, 10 left, 10 right.
movements = ["R"] * 5 + ["L"] * 10 + ["R"] * 10
new_size = (512, 512)  # upscaled size of each frame in the preview GIF
with torch.no_grad():  # inference only: no autograd graph is needed
    for movement in movements:
        x = play_dataset.__getitem__()

        # Inputs must live on the same device as the model.
        moves = x.get('moves').to(device)
        frames = x.get('frames').to(device)

        # Forward pass; output assumed to be (1, seq, C, H, W) -- TODO confirm.
        output = model(moves, frames)

        predicted_frames = output.squeeze(0).permute(0, 2, 3, 1)
        # Clamp before scaling: values outside [0, 1] would wrap around in the uint8 cast.
        predicted_np = (predicted_frames.clamp(0, 1).cpu().numpy() * 255).astype('uint8')
        predicted = predicted_np[4]  # prediction at the last step of the 5-step window
        image_pil = Image.fromarray(predicted)

        play_dataset.move(movement, image_pil)
        results.append(image_pil.resize(new_size, Image.NEAREST))


In [None]:
from IPython.display import Image as ImageD

# Save the scripted run as a looping GIF and show it inline.
first_frame, remaining_frames = results[0], results[1:]
first_frame.save("play.gif", save_all=True, append_images=remaining_frames, duration=0, loop=0)
display(ImageD(filename="play.gif"))

# Replay Recorded Moves While Predicting Frames

In [None]:
class PlayTrainDataset(Dataset):
    """Replays the recorded move sequence while feeding back predicted frames.

    The window is seeded with the first recorded frames/moves from
    ``data_folder``. After that, each :meth:`move` call appends a caller-supplied
    (predicted) frame together with the next *recorded* move from ``moves_dict``.
    """

    def __init__(self, data_folder, moves_dict, seq_len):
        """
        Args:
            data_folder: directory containing ``frame_<i>.png`` files.
            moves_dict: mapping from the step index as a string (``"0"``,
                ``"1"``, ...) to an integer move code.
            seq_len: maximum number of steps kept in the sliding window.
        """
        self.data_folder = data_folder
        self.moves_dict = moves_dict
        self.seq_len = seq_len

        # Get a list of all image filenames in the data folder
        self.image_filenames = [filename for filename in os.listdir(data_folder) if filename.endswith('.png')]
        # Seed with up to `seq_len` steps, using only indices below
        # (number of PNGs - 2).
        n_seed = max(0, min(seq_len, len(self.image_filenames) - 2))
        self.moves = [self.moves_dict[f"{i}"] for i in range(n_seed)]
        self.frames = [self._load_frame(os.path.join(self.data_folder, f'frame_{i}.png')) for i in range(n_seed)]
        # Index of the newest recorded step in the window; `move` reads the move
        # for `counter + 1`. Derived from the seeded window instead of being
        # hard-coded to 4, so it stays correct for any seq_len / folder size.
        self.counter = len(self.moves) - 1

    @staticmethod
    def _load_frame(path):
        """Open an image and copy its pixels so the file handle is closed immediately."""
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, idx=None):
        """Return the current window as batched tensors.

        Args:
            idx: ignored; accepted so ``dataset[i]`` works like any Dataset.

        Returns:
            dict with ``'frames'``, a float tensor of shape
            (1, window_len, C, 128, 128) after ``data_transform_p``, and
            ``'moves'``, an integer tensor of shape (1, window_len).
        """
        frames = torch.stack([data_transform_p(frame) for frame in self.frames]).unsqueeze(0)
        moves = torch.tensor(self.moves).unsqueeze(0)
        return {'frames': frames, 'moves': moves}

    def move(self, frame):
        """Slide the window: append `frame` and the next recorded move.

        Raises:
            KeyError: if ``moves_dict`` has no entry for the next step, i.e.
                the recorded move sequence is exhausted.
        """
        self.counter += 1
        self.moves = self.moves[1:] + [self.moves_dict[f"{self.counter}"]]
        self.frames = self.frames[1:] + [frame]

# Same 128x128 resize + ToTensor transform as in the "Play Game" section,
# redefined so this section can run on its own.
data_transform_p = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

In [None]:
play_train_dataset = PlayTrainDataset(datasetnames[0], new_moves, seq_len=5)
model.eval()
new_size = (512, 512)  # upscaled size of each frame in the preview GIF

# Start the GIF with the 5 real seed frames so the transition to predictions is visible.
results = []
for seed_idx in range(5):
    seed_frame = Image.open(os.path.join(datasetnames[0], f'frame_{seed_idx}.png'))
    results.append(seed_frame.resize(new_size, Image.NEAREST))

In [None]:
# Autoregressive replay: the model sees only its own predicted frames,
# driven by the recorded move sequence.
NUM_REPLAY_STEPS = 500
with torch.no_grad():  # inference only: no autograd graph is needed
    for _ in range(NUM_REPLAY_STEPS):
        x = play_train_dataset.__getitem__()

        moves = x.get('moves').to(device)
        frames = x.get('frames').to(device)

        # Forward pass; output assumed to be (1, seq, C, H, W) -- TODO confirm.
        output = model(moves, frames)

        predicted_frames = output.squeeze(0).permute(0, 2, 3, 1)
        # Clamp before scaling: values outside [0, 1] would wrap around in the uint8 cast.
        predicted_np = (predicted_frames.clamp(0, 1).cpu().numpy() * 255).astype('uint8')
        predicted = predicted_np[4]  # prediction at the last step of the 5-step window
        image_pil = Image.fromarray(predicted)

        play_train_dataset.move(image_pil)
        # `new_size` comes from the setup cell above.
        results.append(image_pil.resize(new_size, Image.NEAREST))


In [None]:
from IPython.display import Image as ImageD

# Save the replay (seed frames + predictions) as a looping GIF and show it inline.
first_frame, remaining_frames = results[0], results[1:]
first_frame.save("play.gif", save_all=True, append_images=remaining_frames, duration=0, loop=0)
display(ImageD(filename="play.gif"))

# Predict next frame only

In [None]:
# One long ground-truth sequence, loaded in order, for single-step prediction.
seq = 500
custom_dataset = CustomDataset(
    datasetnames[0],
    new_moves,
    seq_len=seq,
    transform=data_transform,
    transform_target=data_transform_target,
)
dataloader = DataLoader(
    custom_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
results = []

In [None]:
# Teacher-forced prediction: split the first ground-truth sequence into
# non-overlapping windows and keep the model's prediction for each window.
WINDOW = 5  # number of consecutive steps fed to the model per prediction
new_size = (512, 512)  # upscaled size of each frame in the preview GIF

first_batch = next(iter(dataloader), None)  # only the first batch is used
if first_batch is not None:
    frames = first_batch['frames'].to(device)  # (1, seq, C, H, W), sliced on dim 1 below
    moves = first_batch['moves'].to(device)

    with torch.no_grad():  # inference only: no autograd graph is needed
        # Stop before a trailing partial window, which would have no frame at
        # index WINDOW - 1 and crash the indexing below.
        for i in range(0, frames.size(1) - WINDOW + 1, WINDOW):
            frames2 = frames[:, i:i + WINDOW, :, :, :]
            moves2 = moves[:, i:i + WINDOW]

            # Forward pass; output assumed to have WINDOW steps -- TODO confirm.
            output = model(moves2, frames2)

            predicted_frames = output.squeeze(0).permute(0, 2, 3, 1)
            # Clamp before scaling: values outside [0, 1] would wrap around in the uint8 cast.
            predicted_np = (predicted_frames.clamp(0, 1).cpu().numpy() * 255).astype('uint8')
            image_pil = Image.fromarray(predicted_np[WINDOW - 1])
            results.append(image_pil.resize(new_size, Image.NEAREST))


In [None]:
from IPython.display import Image as ImageD

# One prediction per second, so each window's output can be inspected.
first_frame, remaining_frames = results[0], results[1:]
first_frame.save("play.gif", save_all=True, append_images=remaining_frames, duration=1000, loop=0)
display(ImageD(filename="play.gif"))