In [4]:
import torch

In [2]:
# One uniform random score per position: 4 rows x 6 positions.
noise = torch.rand(4, 6)
# Row-wise ascending order of the scores (small score = keep, large score = remove).
ids_shuffle = noise.argsort(dim=1)
# Arg-sorting a permutation gives its inverse: it maps the shuffled order back to the original.
ids_restore = ids_shuffle.argsort(dim=1)

In [5]:
ids_shuffle

tensor([[3, 2, 5, 0, 1, 4],
        [0, 2, 4, 1, 5, 3],
        [2, 0, 5, 1, 3, 4],
        [4, 3, 0, 1, 2, 5]])

In [3]:
ids_restore

tensor([[3, 4, 1, 0, 5, 2],
        [0, 3, 1, 5, 2, 4],
        [1, 3, 0, 4, 5, 2],
        [2, 3, 4, 1, 0, 5]])

In [7]:
# Peek at the first two shuffled indices of every row.
ids_shuffle[:, :2]

tensor([[3, 2],
        [0, 2],
        [2, 0],
        [4, 3]])

In [8]:
# Start from an all-ones mask with 4 rows and 6 columns.
mask = torch.ones(4, 6)
mask

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [9]:
# Zero the first two columns of every row (in-place update of `mask`).
mask[:, 0:2] = 0.0

In [10]:
mask

tensor([[0., 0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.]])

In [11]:
# Rebuild the sorted-order mask inside this cell so the cell is idempotent:
# the previous form, mask = gather(mask, ...), gathered an already-gathered
# mask whenever the cell was re-run and silently produced a wrong result.
len_keep = 2  # the first `len_keep` sorted positions are marked 0
mask_sorted = torch.ones(ids_restore.shape)
mask_sorted[:, :len_keep] = 0
# Un-shuffle: mask[i][j] = mask_sorted[i][ids_restore[i][j]].
mask = torch.gather(mask_sorted, dim=1, index=ids_restore)
mask

tensor([[1., 1., 0., 0., 1., 1.],
        [0., 1., 0., 1., 1., 1.],
        [0., 1., 0., 1., 1., 1.],
        [1., 1., 1., 0., 0., 1.]])

In [2]:
import torch.nn as nn

In [3]:
# Square linear layer: 128 input features -> 128 output features.
k = nn.Linear(in_features=128, out_features=128)

In [5]:
# Shape check: feed a 4-D input whose last dimension is 128 through the linear layer
# and display the shape of the result.
k(torch.rand(2,4,2,128)).shape

torch.Size([2, 4, 2, 128])

In [17]:
from timm.models.vision_transformer import PatchEmbed, Block
from Model.Module.asymAttention import AsymAttention

In [12]:
# Transformer block with embedding width 128, 4 attention heads and biased QKV projections.
b = Block(dim=128, num_heads=4, qkv_bias=True)

In [15]:
# Replace the block's attention module with AsymAttention, using the same width and
# head count as the block and with both dropouts disabled.
b.attn = AsymAttention(dim=128, num_heads=4, qkv_bias=True, attn_drop=0.0, proj_drop=0.0)

In [18]:
# Patch embedding built with the library defaults (no arguments passed).
p = PatchEmbed()

In [19]:
p

PatchEmbed(
  (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (norm): Identity()
)

In [20]:
b

Block(
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (attn): AsymAttention(
    (q): Linear(in_features=128, out_features=128, bias=True)
    (kv): Linear(in_features=128, out_features=256, bias=True)
    (attn_drop): Dropout(p=0.0, inplace=False)
    (proj): Linear(in_features=128, out_features=128, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (ls1): Identity()
  (drop_path1): Identity()
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=128, out_features=512, bias=True)
    (act): GELU(approximate='none')
    (drop1): Dropout(p=0.0, inplace=False)
    (norm): Identity()
    (fc2): Linear(in_features=512, out_features=128, bias=True)
    (drop2): Dropout(p=0.0, inplace=False)
  )
  (ls2): Identity()
  (drop_path2): Identity()
)

In [1]:
import kagglehub

# Fetch the dataset through kagglehub (the handle pins no version) and keep
# the local directory that the download call returns.
dataset_handle = "sautkin/imagenet1k3"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Downloading from https://www.kaggle.com/api/v1/datasets/download/sautkin/imagenet1k3?dataset_version_number=2...


100%|██████████| 11.5G/11.5G [01:20<00:00, 154MB/s] 

Extracting files...





Path to dataset files: /home/sagemaker-user/.cache/kagglehub/datasets/sautkin/imagenet1k3/versions/2


In [2]:
import os
import random
from pathlib import Path

# --- Configuration -----------------------------------------------------------
root = Path("Data/imagenet")  # change to your dataset root folder
train_file = "train.txt"
val_file = "val.txt"
val_ratio = 0.1  # 10% validation
seed = 42  # fixed seed: re-running the cell reproduces the same split


def _collect_samples(dataset_root):
    """Scan a one-sub-folder-per-class tree and list every file with its label.

    Args:
        dataset_root: ``Path`` whose immediate sub-directories are the classes.

    Returns:
        ``(class_dirs, samples)``: the class directories sorted by path, and a
        list of ``(relative_path, class_index)`` tuples. ``class_index`` is the
        position of the class directory in ``class_dirs``; ``relative_path`` is
        taken relative to ``dataset_root.parent`` so it keeps the root folder
        name, e.g. ``imagenet/00500/img001.jpg``.

    Raises:
        FileNotFoundError: if ``dataset_root`` is not an existing directory.
    """
    if not dataset_root.is_dir():
        raise FileNotFoundError(f"Dataset root not found: {dataset_root}")

    # Sorted class folders give every class a stable numeric label.
    found_class_dirs = sorted(d for d in dataset_root.iterdir() if d.is_dir())

    found_samples = []
    for class_idx, class_dir in enumerate(found_class_dirs):
        # Directory traversal order depends on the filesystem; sorting makes the
        # pre-shuffle order (and therefore the seeded split) deterministic.
        # NOTE(review): every regular file is listed, whatever its extension.
        for file_path in sorted(class_dir.rglob("*")):
            if file_path.is_file():
                rel_path = file_path.relative_to(dataset_root.parent)
                found_samples.append((str(rel_path), class_idx))
    return found_class_dirs, found_samples


def _write_split(file_name, split):
    """Write one ``<relative_path> <label>`` line per sample to ``file_name``."""
    with open(file_name, "w") as f:
        f.writelines(f"{path} {label}\n" for path, label in split)


# --- Build the sample list -----------------------------------------------------
class_dirs, samples = _collect_samples(root)

# Shuffle with a private, seeded generator so the global `random` state is
# left untouched and the split is reproducible.
random.Random(seed).shuffle(samples)

# Split train / val: the first `val_size` shuffled samples become validation.
val_size = int(len(samples) * val_ratio)
val_samples = samples[:val_size]
train_samples = samples[val_size:]

# --- Save the two list files ---------------------------------------------------
_write_split(train_file, train_samples)
_write_split(val_file, val_samples)

print("Done.")
print(f"Total images: {len(samples)}")
print(f"Train samples: {len(train_samples)}, Val samples: {len(val_samples)}")


Done.
Total images: 326025
Train samples: 293423, Val samples: 32602
