In [2]:
import sys
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset

# from sam2.build_sam import build_sam2

sys.path.append(os.path.abspath('/home/dmatveev/workdir/rosneft_segmentation/experiments'))


from prepare_data import SeismicDataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
config_2d = {
        "type": "2D",
        "seismic_dir": "/home/dmatveev/workdir/rosneft_segmentation/data/Salt2d/seismic",
        "label_dir": "/home/dmatveev/workdir/rosneft_segmentation/data/Salt2d/label",
        "shape": (224, 224),
        "mask_dtype": np.uint8,
        "augmentation_pipeline": None
    }

In [3]:
seismic_dataset = SeismicDataset(config_2d)

In [27]:
# for idx in range(10):
#     sample = seismic_dataset[idx]

#     fig, axes = plt.subplots(1, 2, figsize=(12, 6))

#     axes[0].imshow(sample["seismic_img"][:,:,0],cmap='gray')
#     axes[0].axis("off")

#     axes[1].imshow(sample["label"], cmap="gray")
#     axes[1].axis("off")

#     plt.tight_layout()
#     plt.show()



In [29]:
def __getitem__(self, idx):
    sample = self.dataset[idx]
    # Convert image (HWC) to tensor (CHW)
    seismic_img = torch.from_numpy(sample["seismic_img"]).float().permute(2, 0, 1)
    label = torch.from_numpy(sample["label"]).long()
    out = {"seismic_img": seismic_img, "label": label}
    
    # Always add mask_prompt if exists or set to None
    if "mask_prompt" in sample:
        if sample["mask_prompt"] is not None:
            out["mask_prompt"] = torch.from_numpy(sample["mask_prompt"]).long()
        else:
            out["mask_prompt"] = None

    # Always add point_prompt, even if it's None
    point_prompt = sample.get("point_prompt")
    if point_prompt is not None:
        out["point_prompt"] = torch.from_numpy(point_prompt).float()
    else:
        out["point_prompt"] = None

    return out


In [30]:
train_config = {
    "model": {
         "checkpoint": "/home/dmatveev/workdir/rosneft_segmentation/models/sam2.1_hiera_base_plus.pt",
         "config": "configs/sam2.1/sam2.1_hiera_b+.yaml",
         "device": "cuda",
         "use_ia3_adapters": True,
         "ia3_adapter_size": 64,  # example adapter size parameter
         "freeze_base": True,
    },
    "loss": {
         "type": "CrossEntropyLoss",  # or "MSELoss", etc.
         "params": {}
    },
    "training": {
         "epochs": 10,
         "batch_size": 4,
         "lr": 1e-4,
         "use_mask": True,  # whether to use mask prompts during training
         "num_workers": 4,
         "log_interval": 10,
    },
    "clearml": {
         "project_name": "SAM2 Fine Tuning",
         "task_name": "IA3 Adapter Training",
    }
}

In [6]:
sam2_checkpoint = train_config["model"]["checkpoint"]
model_cfg = train_config["model"]["config"]
device = train_config["model"]["device"]


In [7]:
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
sam2_model.to(device)

SAM2Base(
  (image_encoder): ImageEncoder(
    (trunk): Hiera(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 112, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      )
      (blocks): ModuleList(
        (0-1): 2 x MultiScaleBlock(
          (norm1): LayerNorm((112,), eps=1e-06, elementwise_affine=True)
          (attn): MultiScaleAttention(
            (qkv): Linear(in_features=112, out_features=336, bias=True)
            (proj): Linear(in_features=112, out_features=112, bias=True)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((112,), eps=1e-06, elementwise_affine=True)
          (mlp): MLP(
            (layers): ModuleList(
              (0): Linear(in_features=112, out_features=448, bias=True)
              (1): Linear(in_features=448, out_features=112, bias=True)
            )
            (act): GELU(approximate='none')
          )
        )
        (2): MultiScaleBlock(
          (norm1): LayerNorm((112,), eps=1e-06, elementwise_affi

In [9]:
if train_config["model"]["freeze_base"]:
    for param in sam2_model.parameters():
        param.requires_grad = False

In [12]:
class IA3Layer(nn.Module):
    """
    Оборачивает существующий линейный слой (nn.Linear) и добавляет к его выходу
    поэлементное масштабирование. Исходный линейный слой замораживается.
    """
    def __init__(self, module: nn.Linear):
        super().__init__()
        self.module = module  # предобученный (замороженный) слой
        # Инициализируем масштабный вектор единицами (по числу выходных признаков)
        self.scale = nn.Parameter(torch.ones(module.out_features, device=module.weight.device))
    
    def forward(self, input):
        # Вычисляем обычный выход слоя
        out = self.module(input)
        # Применяем поэлементное масштабирование: каждый выходной канал умножается на свой scale
        return out * self.scale

In [13]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
sam2_model.to(device)

# Замораживаем все параметры базовой модели
for param in sam2_model.parameters():
    param.requires_grad = False

In [14]:
num_ia3 = 0
for block in sam2_model.image_encoder.trunk.blocks:
    # Если блок имеет модуль внимания, оборачиваем его проекции
    if hasattr(block, "attn"):
        # Оборачиваем линейный слой qkv
        if hasattr(block.attn, "qkv") and isinstance(block.attn.qkv, nn.Linear):
            block.attn.qkv = IA3Layer(block.attn.qkv).to(device)
            num_ia3 += 1
        # Оборачиваем линейный слой proj
        if hasattr(block.attn, "proj") and isinstance(block.attn.proj, nn.Linear):
            block.attn.proj = IA3Layer(block.attn.proj).to(device)
            num_ia3 += 1
    # Если блок содержит MLP, оборачиваем каждый линейный слой внутри него
    if hasattr(block, "mlp") and hasattr(block.mlp, "layers"):
        for i, layer in enumerate(block.mlp.layers):
            if isinstance(layer, nn.Linear):
                block.mlp.layers[i] = IA3Layer(layer).to(device)
                num_ia3 += 1

In [15]:
print("Number of IA³ adapters inserted:", num_ia3)

Number of IA³ adapters inserted: 96


In [16]:
# После интеграции оставляем для обучения только параметры адаптеров (scale)
for name, param in sam2_model.named_parameters():
    if "scale" not in name:
        param.requires_grad = False

In [17]:
trainable_params = sum(p.numel() for p in sam2_model.parameters() if p.requires_grad)
print("Обучаемых параметров:", trainable_params)


Обучаемых параметров: 96768


In [17]:
for name, param in sam2_model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.numel()} параметров")


image_encoder.trunk.blocks.0.attn.qkv.scale: 336 параметров
image_encoder.trunk.blocks.0.attn.proj.scale: 112 параметров
image_encoder.trunk.blocks.0.mlp.layers.0.scale: 448 параметров
image_encoder.trunk.blocks.0.mlp.layers.1.scale: 112 параметров
image_encoder.trunk.blocks.1.attn.qkv.scale: 336 параметров
image_encoder.trunk.blocks.1.attn.proj.scale: 112 параметров
image_encoder.trunk.blocks.1.mlp.layers.0.scale: 448 параметров
image_encoder.trunk.blocks.1.mlp.layers.1.scale: 112 параметров
image_encoder.trunk.blocks.2.attn.qkv.scale: 672 параметров
image_encoder.trunk.blocks.2.attn.proj.scale: 224 параметров
image_encoder.trunk.blocks.2.mlp.layers.0.scale: 896 параметров
image_encoder.trunk.blocks.2.mlp.layers.1.scale: 224 параметров
image_encoder.trunk.blocks.3.attn.qkv.scale: 672 параметров
image_encoder.trunk.blocks.3.attn.proj.scale: 224 параметров
image_encoder.trunk.blocks.3.mlp.layers.0.scale: 896 параметров
image_encoder.trunk.blocks.3.mlp.layers.1.scale: 224 параметров
imag

In [19]:
import torch.optim as optim


In [20]:
loss_type = train_config["loss"]["type"]
if loss_type == "CrossEntropyLoss":
    criterion = nn.CrossEntropyLoss(**train_config["loss"]["params"])
elif loss_type == "MSELoss":
    criterion = nn.MSELoss(**train_config["loss"]["params"])
else:
    raise ValueError(f"Unsupported loss type: {loss_type}")

optimizer = optim.Adam(filter(lambda p: p.requires_grad, sam2_model.parameters()),
                       lr=train_config["training"]["lr"])

In [31]:
torch_dataset = TorchSeismicDataset(seismic_dataset)

train_loader = DataLoader(torch_dataset,
                          batch_size=train_config["training"]["batch_size"],
                          shuffle=True,
                          num_workers=train_config["training"]["num_workers"])

In [25]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7f085d913bd0>

In [22]:
num_epochs = train_config["training"]["epochs"]
log_interval = train_config["training"]["log_interval"]

In [28]:
log_interval

10

In [32]:
for batch_idx, batch in enumerate(train_loader):
    break

In [33]:
batch_idx

0

In [34]:
sam2_model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        
        # Move data to the specified device
        seismic_imgs = batch["seismic_img"].to(device)   # [B, C, H, W]
        labels = batch["label"].to(device)                 # [B, H, W]
        
        # Optionally pass prompts if available and enabled
        point_prompts = batch.get("point_prompt")
        if point_prompts is not None:
            point_prompts = point_prompts.to(device)
        mask_prompts = batch.get("mask_prompt")
        if mask_prompts is not None and train_config["training"]["use_mask"]:
            mask_prompts = mask_prompts.to(device)
        
        # Forward pass: adjust arguments as per your model's API
        outputs = sam2_model(seismic_imgs, point_prompt=point_prompts, mask_prompt=mask_prompts)
        
        # Compute loss (assuming outputs shape is [B, num_classes, H, W])
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        if batch_idx % log_interval == 0:
            current_iter = epoch * len(train_loader) + batch_idx
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx}/{len(train_loader)}] Loss: {loss.item():.4f}")
            # task.get_logger().report_scalar("loss", "train", iteration=current_iter, value=loss.item())
    
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")
    # task.get_logger().report_scalar("epoch_loss", "train", iteration=epoch, value=avg_loss)
    
    # Save checkpoint for the epoch
    checkpoint_path = f"checkpoint_epoch_{epoch+1}.pt"
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": sam2_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_loss,
    }, checkpoint_path)
    # task.get_logger().report_artifact(name=f"checkpoint_epoch_{epoch+1}", artifact_object=checkpoint_path)


NotImplementedError: Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuningSee notebooks/video_predictor_example.ipynb for an inference example.