In [None]:
import torch 

# Sanity check: nn.Linear only transforms the trailing feature axis, so the
# (64, 200, 3) batch below should print torch.Size([64, 200, 1]).
data = torch.rand((64, 200, 3))

my_layer = torch.nn.Linear(in_features=3, out_features=1)

output = my_layer(data)

print(output.shape)

: 

In [1]:
import torch
# Read the pretrained checkpoint onto the CPU so no GPU is needed to open it.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
weights = torch.load(
    "save/humanml_trans_enc_512/model000200000.pt",
    map_location=torch.device('cpu'),
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

class ZeroConvBlock(nn.Module):
    """Bias-free pointwise (kernel size 1) ``Conv1d`` whose weights start at zero.

    Because every weight is zero and there is no bias, the block outputs all
    zeros until its weights have been updated by training.

    Args:
        input: number of input channels.
        output: number of output channels.
    """

    def __init__(self, input, output):
        super().__init__()
        pointwise = nn.Conv1d(
            in_channels=input,
            out_channels=output,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # The layer has no bias, so the weight is the only tensor to zero out.
        nn.init.zeros_(pointwise.weight)
        self.conv = pointwise

    def forward(self, x):
        """Apply the pointwise convolution to ``x`` of shape (batch, channels, length)."""
        projected = self.conv(x)
        return projected

class ImageEmbedding(nn.Module):
    """Frozen ImageNet-pretrained ResNet-18 used as a fixed image feature extractor.

    The classification head (``fc``) is replaced by ``nn.Identity`` so
    ``forward`` returns the features that would have been fed to it rather
    than class logits.  All backbone parameters have ``requires_grad=False``
    and the backbone is kept permanently in eval mode (see :meth:`train`).
    """

    def __init__(self):
        super(ImageEmbedding, self).__init__()
        # You can replace this with any other suitable architecture.
        # Newer torchvision releases deprecate ``pretrained=True`` in favour of
        # the ``weights=`` enum; fall back to the legacy flag when the enum is
        # not available.  IMAGENET1K_V1 is what ``pretrained=True`` maps to.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            cachedResnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            cachedResnet = models.resnet18(pretrained=True)
        for param in cachedResnet.parameters():
            param.requires_grad = False
        cachedResnet.fc = nn.Identity()
        # requires_grad=False does not stop BatchNorm from using batch
        # statistics and updating its running stats; eval mode does.
        cachedResnet.eval()
        self.cnn = cachedResnet

    def train(self, mode=True):
        """Switch training mode but keep the frozen backbone in eval mode.

        A parent module's ``.train()`` call recurses into this method; without
        the override the backbone's BatchNorm layers would go back to train
        mode and their running statistics would drift on every forward pass.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    def forward(self, x):
        """Return the backbone features for the image batch ``x``.

        The result carries no autograd history: the backbone is frozen, so
        running it under ``no_grad`` avoids storing activations for backward.
        """
        with torch.no_grad():
            return self.cnn(x)


class ModifiedTransformerEncoder(nn.Module):
    """ControlNet-like transformer encoder conditioned on an image.

    Two parallel stacks of ``nn.TransformerEncoderLayer`` are kept:

    * ``originalLayers`` -- frozen (``requires_grad=False``); filled with
      pretrained weights through :meth:`load_original_weights`.
    * ``trainableLayers`` -- trainable; fed the input plus the embedded image
      condition.

    After each layer the trainable branch goes through a zero-initialised
    pointwise convolution and is added to the frozen branch, so the control
    branch contributes nothing until those convolutions have been trained.

    Args:
        num_layers: number of layers in each stack.
        d_model: feature size of the sequence; the image embedding must have
            the same size (512 for the ResNet-18 used by ``ImageEmbedding``).
        nhead: number of attention heads per layer.
        dim_feedforward: hidden size of each layer's feed-forward network.
        dropout: dropout probability used inside each layer.
        activation: activation passed to ``nn.TransformerEncoderLayer``.
    """

    def __init__(self, num_layers, d_model, nhead, dim_feedforward, dropout, activation):
        super(ModifiedTransformerEncoder, self).__init__()
        self.nheads = nhead
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation

        # Set by loadCondition(); forward() refuses to run while it is None.
        self.image_condition = None

        self.imageEmbedding = ImageEmbedding()

        self.inputConv = ZeroConvBlock(d_model, d_model)

        def build_stack():
            # Both stacks share the same architecture; only their roles differ.
            return nn.ModuleList([
                nn.TransformerEncoderLayer(d_model=self.d_model,
                                           nhead=self.nheads,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=self.dropout,
                                           activation=self.activation)
                for _ in range(num_layers)
            ])

        self.originalLayers = build_stack()
        self.trainableLayers = build_stack()

        self.zeroConvLayers = nn.ModuleList(
            [ZeroConvBlock(self.d_model, self.d_model) for _ in range(num_layers)]
        )

        # Use the GPU when one is available and move the whole model there.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        # The original layers are never updated by the optimiser.
        for layer in self.originalLayers:
            for param in layer.parameters():
                param.requires_grad = False

    def _current_device(self):
        """Return the device the parameters live on right now.

        ``self.device`` only records the device chosen at construction time; it
        becomes stale if the module is later moved with ``.to()`` / ``.cpu()``.
        """
        return next(self.parameters()).device

    def loadCondition(self, condition):
        """Store the image batch that conditions subsequent forward passes.

        Args:
            condition: image tensor accepted by ``self.imageEmbedding``; its
                batch size must be 1 or equal to the batch size of the
                sequences later passed to :meth:`forward`.
        """
        self.image_condition = condition.to(self._current_device())

    def forward(self, x):
        """Encode ``x`` under the stored image condition.

        Args:
            x: sequence tensor, expected as (seq_len, batch, d_model), which is
                the default layout of ``nn.TransformerEncoderLayer``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            AttributeError: if :meth:`loadCondition` has not been called.
        """
        condition = getattr(self, "image_condition", None)
        if condition is None:
            raise AttributeError(
                "image_condition is not set; call loadCondition(condition) before forward()."
            )

        device = self._current_device()
        x = x.to(device)
        seq_len, batch_size = x.size(0), x.size(1)

        # One feature vector per condition image: (n_images, features).
        condition_embedding = self.imageEmbedding(condition.to(device))
        condition_embedding = condition_embedding.reshape(condition_embedding.size(0), -1)
        # Broadcast over the sequence axis (and over the batch axis when a
        # single image is given) -> (seq_len, batch, features).
        condition_embedding = condition_embedding.unsqueeze(0).expand(seq_len, batch_size, -1)
        # Conv1d wants (batch, channels, length); permute there and back.
        condition_embedding = self.inputConv(condition_embedding.permute(1, 2, 0)).permute(2, 0, 1)

        trainableOutput = x + condition_embedding
        originalOutput = x

        for originalLayer, trainableLayer, convBlock in zip(
            self.originalLayers, self.trainableLayers, self.zeroConvLayers
        ):
            originalIntermediate = originalLayer(originalOutput)

            trainableOutput = trainableLayer(trainableOutput)

            convOutput = convBlock(trainableOutput.permute(1, 2, 0)).permute(2, 0, 1)

            # Inject the control branch into the frozen branch after each layer.
            originalOutput = originalIntermediate + convOutput

        return originalOutput

    def load_original_weights(self, state_dict):
        """Load pretrained encoder weights into the frozen ``originalLayers``.

        For layer ``i`` every key starting with ``seqTransEncoder.layers.{i}.``
        is taken, the prefix is stripped and the result is loaded strictly.
        Keys that merely contain the prefix somewhere in the middle are ignored.

        Only ``originalLayers`` are populated; ``trainableLayers`` keep their
        random initialisation.
        NOTE(review): ControlNet-style setups usually start the trainable copy
        from the same pretrained weights -- confirm whether that is wanted here.

        Args:
            state_dict: mapping from parameter names to tensors.

        Raises:
            RuntimeError: from ``load_state_dict`` when a layer's keys are
                missing or unexpected.
        """
        for i, layer in enumerate(self.originalLayers):
            prefix = f'seqTransEncoder.layers.{i}.'
            layer_state_dict = {k[len(prefix):]: v
                                for k, v in state_dict.items() if k.startswith(prefix)}
            layer.load_state_dict(layer_state_dict, strict=True)

In [3]:
# Encoder hyper-parameters; the frozen layers are later loaded strictly from a
# checkpoint, so these have to match the architecture that checkpoint came from.
controlformer = ModifiedTransformerEncoder(
    num_layers=8,
    d_model=512,
    nhead=4,
    dim_feedforward=1024,
    dropout=0.1,
    activation='gelu',
)

In [4]:
# Copy the pretrained encoder weights into the frozen `originalLayers`.
controlformer.load_original_weights(state_dict=weights)

In [None]:
# Random stand-in for the conditioning images: 64 images, 3 channels, 480x480.
condition = torch.randn((64, 3, 480, 480))

In [None]:
# Register the image batch that the next forward pass will be conditioned on.
controlformer.loadCondition(condition=condition)

In [None]:
# Dummy sequence batch laid out as (197, 64, 512); dim 1 matches the 64
# condition images and the last dim matches d_model.
# NOTE: the name shadows the builtin `input`; it is kept because the next cell
# reads it.
input = torch.randn((197, 64, 512))

In [None]:
# Run the conditioned encoder on the dummy sequence; this uses the image batch
# previously stored with loadCondition().
x = controlformer(input)

In [None]:
# Shape of the output once the first position along dim 0 is dropped.
x[1:].shape

In [None]:
from torchviz import make_dot

# Draw the autograd graph of the forward pass computed above.
# `x` / `controlformer` are used because, when the notebook runs top to bottom,
# `model` is only loaded in a later cell (NameError here) and `output` still
# refers to the small nn.Linear demo from the first cell.
dot = make_dot(x, params=dict(controlformer.named_parameters()))
dot.format = 'png'
dot.render('model_visualization')


In [None]:
import torch
import pickle

# SECURITY: pickle.load can execute arbitrary code stored in the file -- only
# open 'model.pkl' when it comes from a trusted source.
with open('model.pkl', 'rb') as pickled_model:
    model = pickle.load(pickled_model)
# model.eval()  # Set the model to evaluation mode

In [None]:
# Display the unpickled object's repr to inspect what was loaded.
model

In [None]:
# Random 4-D batch fed to `model.input_process` below.
# NOTE(review): the (2, 22, 263, 90) layout is assumed to be what that layer
# expects -- confirm against the model definition.
dummy_input = torch.randn((2, 22, 263, 90))

In [None]:
# Pull out the model's `input_process` attribute so it can be run and
# visualised on its own in the cells below.
layer = model.input_process

In [None]:
from torchinfo import summary

In [None]:
# Push the dummy batch through the extracted layer and render the autograd
# graph of the result as a PNG.
output = layer(dummy_input)
dot = make_dot(
    output,
    params={name: param for name, param in layer.named_parameters()},
)
dot.format = 'png'
dot.render('model_visualization')

In [None]:
from torchviz import make_dot

# NOTE(review): no model is called here -- the (1, 263) random tensor is handed
# to make_dot directly, so there is no autograd history to draw. Confirm
# whether a forward pass through the model was intended.
dummy_input = torch.randn((1, 263))
dot = make_dot(dummy_input)
dot.render("model_architecture.gv", view=False)  # Save as a Graphviz file
