In [1]:
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

class DINOHead(nn.Module):
    """Fully-convolutional projection head in the style of the DINO head.

    The head is a stack of 3x3 convolutions (stride 1, padding 1, so height
    and width are preserved) that maps ``in_chans`` channels to
    ``bottleneck_dim`` channels, followed by an L2 normalisation over the
    channel axis and a weight-normalised, bias-free 1x1 convolution that
    maps to ``out_chans`` channels.

    Args:
        in_chans: Number of channels of the input feature map.
        out_chans: Number of channels produced by the final 1x1 convolution.
        use_bn: If True, insert ``BatchNorm2d`` after every hidden convolution.
        norm_last_layer: If True, freeze the magnitude parameter
            (``weight_g``) of the final weight-normalised layer at 1.
        nlayers: Number of convolutions in the trunk; values below 1 are
            treated as 1.
        hidden_dim: Channel width of the hidden convolutions (unused when
            ``nlayers`` is 1).
        bottleneck_dim: Channel width fed into the final 1x1 convolution.
    """

    def __init__(self, in_chans, out_chans, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.in_chans = in_chans

        if nlayers == 1:
            # Degenerate trunk: a single convolution straight to the bottleneck.
            self.mlp = nn.Conv2d(in_chans, bottleneck_dim, 3, padding=1)
        else:
            layers = [nn.Conv2d(in_chans, hidden_dim, 3, padding=1)]
            if use_bn:
                layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.GELU())
            # nlayers counts the first and the bottleneck convolution too,
            # hence only nlayers - 2 hidden-to-hidden blocks in between.
            for _ in range(nlayers - 2):
                layers.append(nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1))
                if use_bn:
                    layers.append(nn.BatchNorm2d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Conv2d(hidden_dim, bottleneck_dim, 3, padding=1))
            self.mlp = nn.Sequential(*layers)

        # Custom init runs before last_layer exists, so it only touches the
        # trunk; last_layer keeps PyTorch's default Conv2d initialisation.
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Conv2d(bottleneck_dim, out_chans, 1, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Initialise a ``Conv2d``: truncated-normal weights (std 0.02), zero bias.

        Modules of any other type are left untouched.
        """
        if isinstance(m, nn.Conv2d):
            # torch's built-in has the same signature and defaults as the timm
            # helper, so the class does not need the third-party import.
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the trunk, L2-normalise over channels, then apply the 1x1 output conv.

        Args:
            x: Input tensor; the ``Conv2d`` layers expect ``in_chans`` channels
                on dim 1.

        Returns:
            Tensor with ``out_chans`` channels on dim 1 and the same spatial
            size as ``x``.
        """
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=1, p=2)  # Normalize along channel dimension
        x = self.last_layer(x)
        return x

In [2]:
# Smoke test: push one random 3-channel 224x224 image through a default head
# configured for 3 input and 3 output channels.
model = DINOHead(3, 3)
input_image = torch.randn((1, 3, 224, 224))
output = model(input_image)

In [22]:
output.size()  # same torch.Size as the .shape attribute

torch.Size([1, 3, 224, 224])

In [25]:
output[:1].size()  # slicing dim 0 only; the remaining dims are kept whole

torch.Size([1, 3, 224, 224])

In [18]:
# Pair of independent uniform-random batches, each 4 x 3 x 224 x 224.
noise = tuple(torch.rand(4, 3, 224, 224) for _ in range(2))

In [21]:
import torch

# Build a pair of random batches, each shaped (4, 3, 224, 224).
noise = tuple(torch.rand(4, 3, 224, 224) for _ in range(2))

# Join the two batches end-to-end along dim 0 (the batch dimension).
concatenated_noise = torch.cat(noise, 0)

# Show the merged shape as a sanity check.
print(concatenated_noise.size())


torch.Size([8, 3, 224, 224])


In [20]:
# `noise` is a tuple of tensors, so it has no `.shape` of its own;
# report the shape of each member instead.
[t.shape for t in noise]

[torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224])]