In [7]:
#@title Mount google drive
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [8]:
#@title Imports
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils import data as D
import os, re

In [None]:
pip install bitnet


In [None]:
pip install torchcfm
from torchcfm.conditional_flow_matching import *
from torchcfm.utils import plot_trajectories, torch_wrapper

# Data loading
- Need to implement way to load using DataLoader class of pytorch.
- Create way to create offset tensor set for "groud truth" embeddings to train

In [9]:
class SimData(D.Dataset):
    def __init__(self, file_paths):
        """
        Args:
            file_paths (list of str): List of paths to the .pt files.
        """
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load the tensor stored at the idx-th path
        # Check if file exists
        file_path = self.file_paths[idx]
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"No such file or directory: '{file_path}'")

        data = torch.load(file_path)
        # Ensure the data has the expected dimensions
        if data.size(0) < 2:
            raise ValueError(f"Data at index {idx} does not have the expected dimensions.")

        # Select the data and stepAhead based on the description
        data_current = data[:-1]  # All but the last time step
        stepAhead = data[1:]      # All but the first time step
        return data_current, stepAhead

In [10]:
folder_path = "/content/drive/MyDrive/ProteinBindingProject/encodedData" #@param
# Check if the directory exists
if not os.path.isdir(folder_path):
    raise FileNotFoundError(f"Directory '{folder_path}' does not exist.")
# Get the list of .pt files sorted by the number present at the beginning of the file name
pt_files = sorted([f for f in os.listdir(folder_path) if f.endswith('.pt')],
    key=lambda x: float(re.match(r'batch_(\d+)_', x).group(1)))
file_paths = []
for pt_file in pt_files:
    file_paths.append(os.path.join(folder_path, pt_file))

for file_path in file_paths:
    print(os.path.exists(file_path))

# Create dataset
dataset = SimData(file_paths)

# Create DataLoader
data_loader = D.DataLoader(dataset, batch_size=1, shuffle=False)



True
True
True
True


# Diffusion Module from open source implementation of Alphafold3

Below are diffusion modules to test for inference of next time point using diffusion, once this is achieved next step will be to implent KAN as predicitve network to create a more tracatable predictive module

In [None]:
#@title Initial Sucess Diffusion steps set at 100 for a model to fit in TPU system RAM size
#@markdown - Need to finalize padding size
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def context_size():
  x = 1320 #@param
  return x

class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = int(num_heads)  # Ensure num_heads is an integer
        self.num_groups = num_groups
        self.head_dim = embed_dim // self.num_heads  # Ensure head_dim is an integer

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.embed_dim, "Input embedding dimension must match model embedding dimension"

        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, num_heads, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        assert seq_len % self.num_groups == 0, "seq_len must be divisible by num_groups"
        group_size = seq_len // self.num_groups

        # Group queries
        q_groups = q.view(batch_size, self.num_heads, self.num_groups, group_size, self.head_dim)  # (batch_size, num_heads, num_groups, group_size, head_dim)
        k_groups = k.view(batch_size, self.num_heads, self.num_groups, group_size, self.head_dim)
        v_groups = v.view(batch_size, self.num_heads, self.num_groups, group_size, self.head_dim)

        # Compute attention for each group
        attn_scores = torch.einsum('bhgqd,bhgkd->bhgqk', q_groups, k_groups)  # (batch_size, num_heads, num_groups, group_size, group_size)
        attn_scores = attn_scores / (self.head_dim ** 0.5)
        attn_probs = F.softmax(attn_scores, dim=-1)

        attn_output = torch.einsum('bhgqk,bhgvd->bhgqd', attn_probs, v_groups)  # (batch_size, num_heads, num_groups, group_size, head_dim)
        attn_output = attn_output.contiguous().view(batch_size, self.num_heads, seq_len, self.head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.out_proj(attn_output)

def modulate(normed_x, shift, scale):
    return normed_x * (1 + scale) + shift

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups, feedforward_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.attention = GroupedQueryAttention(embed_dim, num_heads, num_groups)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, feedforward_dim),
            nn.ReLU(),
            nn.Linear(feedforward_dim, embed_dim)
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim, bias=True)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)
        x = x + gate_msa * self.attention(modulate(self.norm1(x), shift_msa, scale_msa))
        x = self.norm1(x)
        x = x + gate_mlp * self.feedforward(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = self.norm2(x)
        return x

class GeneticDiffusionModuleBlock(nn.Module):
    def __init__(self, channels: int, num_diffusion_steps: int = 100, training: bool = False, depth: int = 3):
        super(GeneticDiffusionModuleBlock, self).__init__()
        self.channels = channels
        assert channels % 8 == 0, "channels must be divisible by 64"
        num_heads = channels // 8
        self.num_diffusion_steps = num_diffusion_steps
        self.time_embeddings = nn.Parameter(torch.randn(num_diffusion_steps, context_size(), channels))
        self.training = training
        self.depth = depth
        self.noise_scale = nn.Parameter(torch.linspace(1.0, 0.01, num_diffusion_steps))

        # Custom transformer layers
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(embed_dim=channels, num_heads=num_heads, num_groups=num_heads//2, feedforward_dim=channels * 4) for _ in range(3)
        ])

    def forward(self, x: Tensor = None, ground_truth: Tensor = None):
        batch_size, num_nodes, k = x.size()
        x_1 = x.clone()

        # Simulate the multi-step diffusion process
        for step in range(self.num_diffusion_steps):
            noise_level = self.noise_scale[step]  # Get the noise level for the current step
            noise = torch.randn_like(x) * noise_level  # Generate noise scaled by the noise level
            x_1 = x_1 + noise  # Add noise to the input

            c = self.time_embeddings[step] # Get the time embedding for the current step
            c = c.unsqueeze(0).repeat(batch_size, 1, 1)  # Shape: [batch_size, 1000, channels]

            # Apply custom transformer layers
            for transformer_layer in self.transformer_layers:
                x_1 = transformer_layer(x_1, c)


        if self.training and ground_truth is not None:
            loss = F.mse_loss(x_1, ground_truth)
            return x_1, loss

        return x_1

class GeneticDiffusion(nn.Module):
    def __init__(self, channels: int, num_diffusion_steps: int = 10, k: int = 64, embeddings: int = 384, training: bool = False, depth: int = 3):
        super(GeneticDiffusion, self).__init__()
        self.channels = channels
        self.num_diffusion_steps = num_diffusion_steps
        self.training = training
        self.depth = depth
        self.convlayers = nn.Conv1d(k*embeddings, channels, 1)
        # Layers
        self.layers = nn.ModuleList([
            GeneticDiffusionModuleBlock(channels, num_diffusion_steps, training, depth) for _ in range(depth)
        ])

    def forward(self, x: Tensor = None, ground_truth: Tensor = None):
        # Assuming input x shape is [batch_size, nodes, k, embeddings]
        batch_size, nodes, k, embeddings = x.size()
        # Flatten the k neighbors' embeddings per node into a 1D vector
        x = x.view(batch_size, nodes, -1)  # [batch_size, nodes, k * embeddings]
        # Pad nodes to ensure a total of context_size nodes per batch
        padding = torch.zeros(batch_size, context_size() - nodes, k * embeddings).to(x.device)
        if nodes < context_size():
            x = torch.cat([x, padding], dim=1)  # [batch_size, context_size, k * embeddings]

        x = x.permute(0, 2, 1)  # [batch_size, k * embeddings, context_size]
        x = self.convlayers(x)  # [batch_size, channels, context_size]
        x = x.permute(0, 2, 1)  # [batch_size, context_size, channels]
        # Apply GeneticDiffusionModuleBlock
        loss = None
        if self.training and ground_truth is not None:
            ground_truth = ground_truth.view(batch_size, nodes, -1)
            if nodes<context_size():
              ground_truth = torch.cat([ground_truth, padding], dim=1)
            ground_truth = ground_truth.permute(0, 2, 1)
            ground_truth = self.convlayers(ground_truth)
            ground_truth = ground_truth.permute(0, 2, 1)
            for layer in self.layers:
                x, loss = layer(x, ground_truth)
            return x, loss
        else:
            for layer in self.layers:
                x = layer(x)
            return x

# # Example usage with fixed number of nodes per batch
# batch_size = 10
# nodes = 50  # Fixed number of nodes
# k = 64
# embeddings = 384
# channels = 128  # Number of channels for the diffusion

# # Creating a dataset with a fixed number of nodes
# dummy_inputs = [torch.randn(batch_size, nodes, k, embeddings)]
# dummy_ground_truths = [torch.randn(batch_size, nodes, k, embeddings)]
# num_diffusion_steps = 100 #@param
# model = GeneticDiffusion(channels=channels, num_diffusion_steps=num_diffusion_steps, training=True, depth=3)

# for dummy_input, dummy_ground_truth in zip(dummy_inputs, dummy_ground_truths):
#     output, loss = model(dummy_input, dummy_ground_truth)
#     print(f"Output shape: {output.shape}")
#     print(f"Loss: {loss.item()}")

In [17]:
#@title Trying with Conv of nodes also
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

#@markdown Need to figure out best context scale size (must be a factor od 2000)
context_scale = 10 #@param

class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = int(num_heads)  # Ensure num_heads is an integer
        self.num_groups = num_groups
        self.head_dim = embed_dim // self.num_heads  # Ensure head_dim is an integer

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.embed_dim, "Input embedding dimension must match model embedding dimension"

        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, num_heads, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        assert seq_len % self.num_groups == 0, "seq_len must be divisible by num_groups"
        group_size = seq_len // self.num_groups

        # Group queries
        q_groups = q.view(batch_size, self.num_heads, self.num_groups, group_size, self.head_dim)  # (batch_size, num_heads, num_groups, group_size, head_dim)
        k_groups = k.view(batch_size, self.num_heads, self.num_groups, group_size, self.head_dim)
        v_groups = v.view(batch_size, self.num_heads, self.num_groups, group_size, self.head_dim)

        # Compute attention for each group
        attn_scores = torch.einsum('bhgqd,bhgkd->bhgqk', q_groups, k_groups)  # (batch_size, num_heads, num_groups, group_size, group_size)
        attn_scores = attn_scores / (self.head_dim ** 0.5)
        attn_probs = F.softmax(attn_scores, dim=-1)

        attn_output = torch.einsum('bhgqk,bhgvd->bhgqd', attn_probs, v_groups)  # (batch_size, num_heads, num_groups, group_size, head_dim)
        attn_output = attn_output.contiguous().view(batch_size, self.num_heads, seq_len, self.head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.out_proj(attn_output)

def modulate(normed_x, shift, scale):
    return normed_x * (1 + scale) + shift

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups, feedforward_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.attention = GroupedQueryAttention(embed_dim, num_heads, num_groups)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, feedforward_dim),
            nn.ReLU(),
            nn.Linear(feedforward_dim, embed_dim)
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim, bias=True)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)
        x = x + gate_msa * self.attention(modulate(self.norm1(x), shift_msa, scale_msa))
        x = self.norm1(x)
        x = x + gate_mlp * self.feedforward(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = self.norm2(x)
        return x

class GeneticDiffusionModuleBlock(nn.Module):
    def __init__(self, channels: int, num_diffusion_steps: int = 100, training: bool = False, depth: int = 3):
        super(GeneticDiffusionModuleBlock, self).__init__()
        self.channels = channels
        assert channels % 8 == 0, "channels must be divisible by 64"
        num_heads = channels // 8
        self.num_diffusion_steps = num_diffusion_steps
        self.time_embeddings = nn.Parameter(torch.randn(num_diffusion_steps, 2000//context_scale, channels))
        self.training = training
        self.depth = depth
        self.noise_scale = nn.Parameter(torch.linspace(1.0, 0.01, num_diffusion_steps))

        # Custom transformer layers
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(embed_dim=channels, num_heads=num_heads, num_groups=num_heads//2, feedforward_dim=channels * 4) for _ in range(3)
        ])

    def forward(self, x: Tensor = None, ground_truth: Tensor = None):
        batch_size, num_nodes, k = x.size()
        x_1 = x.clone()

        # Simulate the multi-step diffusion process
        for step in range(self.num_diffusion_steps):
            noise_level = self.noise_scale[step]  # Get the noise level for the current step
            noise = torch.randn_like(x) * noise_level  # Generate noise scaled by the noise level
            x_1 = x_1 + noise  # Add noise to the input

            c = self.time_embeddings[step] # Get the time embedding for the current step
            c = c.unsqueeze(0).repeat(batch_size, 1, 1)  # Shape: [batch_size, 1000, channels]

            # Apply custom transformer layers
            for transformer_layer in self.transformer_layers:
                x_1 = transformer_layer(x_1, c)


        if self.training and ground_truth is not None:
            loss = F.mse_loss(x_1, ground_truth)
            return x_1, loss

        return x_1

class GeneticDiffusion(nn.Module):
    def __init__(self, channels: int, num_diffusion_steps: int = 10, k: int = 64, embeddings: int = 384, training: bool = False, depth: int = 3):
        super(GeneticDiffusion, self).__init__()
        self.channels = channels
        self.num_diffusion_steps = num_diffusion_steps
        self.training = training
        self.depth = depth

        # 2D convolutional layer to transform the input tensor from [10, 2000, 24576] to [10, 200, 128]
        self.conv = nn.Conv2d(in_channels=1,
                                out_channels=1,
                                kernel_size=(context_scale, 192),
                                stride=(context_scale, 192),
                                padding=0)
        # Layers
        self.layers = nn.ModuleList([
            GeneticDiffusionModuleBlock(channels, num_diffusion_steps, training, depth) for _ in range(depth)
        ])

    def forward(self, x: Tensor = None, ground_truth: Tensor = None):
        # Assuming input x shape is [batch_size, nodes, k, embeddings]
        batch_size, nodes, k, embeddings = x.size()
        # Flatten the k neighbors' embeddings per node into a 1D vector
        x = x.view(batch_size, nodes, -1)  # [batch_size, nodes, k * embeddings]
        # Pad nodes to ensure a total of context_size nodes per batch
        padding = torch.zeros(batch_size, 2000 - nodes, k * embeddings).to(x.device)
        if nodes < 2000:
            x = torch.cat([x, padding], dim=1)  # [batch_size, context_size, k * embeddings]
            if self.training and ground_truth is not None:
              ground_truth = ground_truth.view(batch_size, nodes, -1)
              ground_truth = torch.cat([ground_truth, padding], dim=1)
              ground_truth = ground_truth.unsqueeze(1)  # Add channel dimension
              ground_truth = self.conv(ground_truth)
              # Reshape back to desired shape: [batch_size, height, width]
              ground_truth = ground_truth.squeeze(1)


        # Reshape to fit Conv2d input: [batch_size, channels, height, width]
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv(x)
        # Reshape back to desired shape: [batch_size, height, width]
        x = x.squeeze(1)  # Remove channel dimension

        # Apply GeneticDiffusionModuleBlock
        loss = None
        if self.training and ground_truth is not None:
            for layer in self.layers:
                x, loss = layer(x, ground_truth)
            return x, loss
        else:
            for layer in self.layers:
                x = layer(x)
            return x

# Example usage with fixed number of nodes per batch
batch_size = 10
nodes = 1929  # Fixed number of nodes
k = 64
embeddings = 384
channels = 128  # Number of channels for the diffusion

# Creating a dataset with a fixed number of nodes
dummy_inputs = [torch.randn(batch_size, nodes, k, embeddings)]
dummy_ground_truths = [torch.randn(batch_size, nodes, k, embeddings)]
num_diffusion_steps = 100 #@param
model = GeneticDiffusion(channels=channels, num_diffusion_steps=num_diffusion_steps, training=True, depth=3)

for dummy_input, dummy_ground_truth in zip(dummy_inputs, dummy_ground_truths):
    output, loss = model(dummy_input, dummy_ground_truth)
    print(f"Output shape: {output.shape}")
    print(f"Loss: {loss.item()}")



Output shape: torch.Size([10, 200, 128])
Loss: 1.3380261659622192


This diffuison model with convulution of the nodes was able to run on about 100 gigs with the batch size that I have currently of 10

# Diffusion Module from with BitNet implementation

Below are diffusion modules to test BitNet implemenation

In [None]:
#@markdown - Need to finalize padding size
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

pip install bitnet
from bitnet import BitFeedForward, BitMGQA

def context_size():
  x = 2000 #@param
  return x

class Transformer(nn.Module):
    """
    Transformer module that applies multi-head attention and feed-forward layers.

    Args:
        dim (int): The dimension of the input and output tensors.
        heads (int): The number of attention heads.
        depth (int): The number of transformer layers.
        ff_mult (int, optional): The multiplier for the hidden dimension in the feed-forward layers.
            Defaults to 2.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Attributes:
        layers (nn.ModuleList): List of multi-head attention layers.
        ffn_layers (nn.ModuleList): List of feed-forward layers.

    """

    def __init__(
        self, dim: int, heads: int, depth: int, ff_mult: int = 2, *args, **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.ffn_layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(BitMGQA(dim, heads, *args, **kwargs))

            self.ffn_layers.append(
                BitFeedForward(
                    dim,
                    dim,
                    ff_mult,
                    swish=True,
                    post_act_ln=True,
                    dropout=0.1,
                ),
            )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        skip = x
        for attn, ffn in zip(self.layers, self.ffn_layers):
            x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
            x = x + skip
            x = ffn(x) + x
        return x

class Diffusion(nn.Module):
    def __init__(self, channels: int, num_diffusion_steps: int = 10, k: int = 64, embeddings: int = 384, training: bool = False, depth: int = 3):
        super(Diffusion, self).__init__()
        self.channels = channels
        self.num_diffusion_steps = num_diffusion_steps
        self.training = training
        self.depth = depth
        self.conv = nn.Conv1d(k*embeddings, channels, 1)
        # Layers
        self.layers = nn.ModuleList([
            GeneticDiffusionModuleBlock(channels, num_diffusion_steps, training, depth) for _ in range(depth)
        ])

    def forward(self, x: Tensor = None, ground_truth: Tensor = None):
        # Assuming input x shape is [batch_size, nodes, k, embeddings]
        batch_size, nodes, k, embeddings = x.size()
        # Flatten the k neighbors' embeddings per node into a 1D vector
        x = x.view(batch_size, nodes, -1)  # [batch_size, nodes, k * embeddings]
        # Pad nodes to ensure a total of context_size nodes per batch
        padding = torch.zeros(batch_size, context_size() - nodes, k * embeddings).to(x.device)
        if nodes < context_size():
            x = torch.cat([x, padding], dim=1)  # [batch_size, context_size, k * embeddings]

        x = x.permute(0, 2, 1)  # [batch_size, k * embeddings, context_size]
        x = self.conv(x)  # [batch_size, channels, context_size]
        x = x.permute(0, 2, 1)  # [batch_size, context_size, channels]
        # Apply GeneticDiffusionModuleBlock
        loss = None
        if self.training and ground_truth is not None:
            ground_truth = ground_truth.view(batch_size, nodes, -1)
            if nodes<context_size():
              ground_truth = torch.cat([ground_truth, padding], dim=1)
            ground_truth = ground_truth.permute(0, 2, 1)
            ground_truth = self.convlayers(ground_truth)
            ground_truth = ground_truth.permute(0, 2, 1)
            for layer in self.layers:
                x, loss = layer(x, ground_truth)
            return x, loss
        else:
            for layer in self.layers:
                x = layer(x)
            return x

# Parameters to Test

In [18]:
channels = 128 #@param
num_diffusion_steps = 100 #@param
depth = 3 #@param
learnRate=1e-3 #@param

model = GeneticDiffusion(channels=channels,num_diffusion_steps=num_diffusion_steps,training=True, depth=depth)
optimizer = torch.optim.AdamW(model.parameters(), lr=learnRate)
epochs = 3 #@param


# Training loop for diffusion model
- create loop for training diffusion model
- ouput trained weights to be used for further training

In [19]:
for data, stepAhead in data_loader:    # Process your data
    # data will be a batch of your tensors
    # stepAhead will be the ground truth tensor
    # Train your model
    data = data[0]
    stepAhead = stepAhead[0]
    print(f"Data shape: {data.shape}")
    print(f"StepAhead shape: {stepAhead.shape}")
    for _ in range(epochs):
        optimizer.zero_grad()
        y, loss = model(data, stepAhead)
        print(f"Loss: {loss.item()}")
        loss.backward()
        optimizer.step()


Data shape: torch.Size([9, 1079, 64, 384])
StepAhead shape: torch.Size([9, 1079, 64, 384])
Loss: 1.0224237442016602
Loss: 1.018386721611023
Loss: 1.0030959844589233
Data shape: torch.Size([9, 1079, 64, 384])
StepAhead shape: torch.Size([9, 1079, 64, 384])
Loss: 1.0045024156570435
Loss: 0.9985254406929016
Loss: 0.9880445599555969
Data shape: torch.Size([9, 1079, 64, 384])
StepAhead shape: torch.Size([9, 1079, 64, 384])
Loss: 0.9712555408477783
Loss: 0.9461786150932312
Loss: 0.9121125340461731
Data shape: torch.Size([9, 1079, 64, 384])
StepAhead shape: torch.Size([9, 1079, 64, 384])
Loss: 0.8726099133491516
Loss: 0.8312634229660034
Loss: 0.7910821437835693
