In [1]:
# Mount Google Drive so the DFPIR project directory under MyDrive is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
cd drive/MyDrive/DFPIR_project/DFPIR-main/

/content/drive/MyDrive/DFPIR_project/DFPIR-main


In [3]:
%%writefile requirements.txt
# Python dependencies for the DFPIR project; installed by the next cell with `pip install -r`.
torch>=1.8.0
torchvision
openai-clip
numpy
scikit-image
scikit-video
scipy
matplotlib
einops
huggingface-hub
tqdm
tensorboard

Overwriting requirements.txt


In [4]:
!pip install -r requirements.txt

Collecting openai-clip (from -r requirements.txt (line 3))
  Downloading openai-clip-1.0.1.tar.gz (1.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.4/1.4 MB 67.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting scikit-video (from -r requirements.txt (line 6))
  Downloading scikit_video-1.1.11-py2.py3-none-any.whl.metadata (1.1 kB)
Collecting ftfy (from openai-clip->-r requirements.txt (line 3))
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading scikit_video-1.1.11-py2.py3-none-any.whl (2.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 105.4 MB/s eta 0:00:00
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.8/44.8 kB 5.5 MB/s eta 0:00:00
Building wheels for collected packages: openai-clip
  Building wheel for openai-clip (setup.py) ...

# Modify the model definition (`net/model.py`)

In [12]:
%%writefile net/model.py
## PromptIR: Prompting for All-in-One Blind Image Restoration
## Vaishnav Potlapalli, Syed Waqas Zamir, Salman Khan, and Fahad Shahbaz Khan
## https://arxiv.org/abs/2306.13090


import torch, torchvision
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx
import numbers

from einops import rearrange
from einops.layers.torch import Rearrange
import time

import clip
import os
import sys
from net.arch_util import LayerNorm2d
from net.local_arch import Local_Base
import torchvision.transforms as transforms
from huggingface_hub import PyTorchModelHubMixin

##########################################################################
## Layer Norm

def to_3d(x):
    """Flatten the spatial axes of a feature map into a token axis.

    (B, C, H, W) -> (B, H*W, C), i.e. channels move to the last axis.
    """
    flatten_pattern = 'b c h w -> b (h w) c'
    return rearrange(x, flatten_pattern)

def to_4d(x, h, w):
    """Inverse of ``to_3d``: (B, H*W, C) -> (B, C, H, W) for the given ``h`` and ``w``."""
    unflatten_pattern = 'b (h w) c -> b c h w'
    return rearrange(x, unflatten_pattern, h=h, w=w)

class BiasFree_LayerNorm(nn.Module):
    """LayerNorm over the last axis with a learnable scale and no shift.

    The input is divided by its (biased) standard deviation along the last
    axis; unlike standard LayerNorm, the mean is NOT subtracted.

    Args:
        normalized_shape: size of the last axis (int or 1-element shape).
    """
    def __init__(self, normalized_shape):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        shape = torch.Size(shape)
        # Only a single trailing normalized axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        scale = torch.sqrt(variance + 1e-5)
        return x / scale * self.weight

class WithBias_LayerNorm(nn.Module):
    """Standard LayerNorm over the last axis with learnable scale and shift.

    Args:
        normalized_shape: size of the last axis (int or 1-element shape).
    """
    def __init__(self, normalized_shape):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        shape = torch.Size(shape)
        # Only a single trailing normalized axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        centered = x - mean
        return centered / torch.sqrt(variance + 1e-5) * self.weight + self.bias

class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for (B, C, H, W) feature maps.

    The map is flattened to (B, H*W, C), normalized over C by the selected
    variant, then restored to (B, C, H, W).

    Args:
        dim: number of channels C.
        LayerNorm_type: 'BiasFree' selects BiasFree_LayerNorm; any other value
            falls back to WithBias_LayerNorm.
    """
    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        return to_4d(self.body(tokens), height, width)

##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated-Dconv feed-forward network (GDFN).

    A 1x1 conv expands to 2*hidden channels and a depthwise 3x3 conv mixes
    spatially. The result is split into two halves, GELU(first half) gates the
    second, and a final 1x1 conv projects back to ``dim`` channels.

    Args:
        dim: input/output channels.
        ffn_expansion_factor: hidden = int(dim * ffn_expansion_factor).
        bias: whether the convolutions use a bias.
    """
    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2

        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        mixed = self.dwconv(self.project_in(x))
        gate, value = mixed.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)

##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Multi-Dconv Head Transposed self-attention (MDTA).

    Attention is taken across channels rather than pixels. Per head, the
    (C/heads x C/heads) affinity of L2-normalized query/key channel vectors is
    scaled by a learnable per-head temperature, softmaxed, and applied to the
    values.

    Args:
        dim: channels C; must be divisible by ``num_heads``.
        num_heads: number of heads.
        bias: whether the convolutions use a bias.
    """
    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape
        to_heads = 'b (head c) h w -> b head c (h w)'
        q, k, v = (
            rearrange(part, to_heads, head=self.num_heads)
            for part in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        )
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        weights = ((q @ k.transpose(-2, -1)) * self.temperature).softmax(dim=-1)
        merged = rearrange(weights @ v, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        return self.project_out(merged)

class resblock(nn.Module):
    """Residual block: conv3x3 -> PReLU -> conv3x3 (no bias), plus an identity skip."""
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
        )

    def forward(self, x):
        return self.body(x) + x


##########################################################################
## Resizing modules
class Downsample(nn.Module):
    """Halve spatial resolution and (for even ``n_feat``) double the channels.

    A 3x3 conv maps n_feat -> n_feat//2 channels, then PixelUnshuffle(2) folds
    each 2x2 neighbourhood into channels, giving 4*(n_feat//2) output channels.
    """
    def __init__(self, n_feat):
        super().__init__()
        reduce = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(reduce, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    """Double spatial resolution and halve the channels.

    A 3x3 conv maps n_feat -> 2*n_feat channels, then PixelShuffle(2) spreads
    every 4 channels over a 2x2 neighbourhood, giving n_feat//2 output channels.
    """
    def __init__(self, n_feat):
        super().__init__()
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

##########################################################################
## Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm Restormer block: x + MDTA(LN(x)), followed by x + GDFN(LN(x))."""
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))

##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Embed an image into ``embed_dim`` channels with a single 3x3, stride-1, padded conv (H, W kept)."""
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)

##########################################################################
##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    """Generate an input-conditioned prompt from a bank of learnable prompt components.

    The input is spatially averaged, and a linear layer plus softmax turns it
    into mixing weights over ``prompt_len`` stored components. The weighted sum
    is resized to the input's H x W and refined with a 3x3 conv.

    Args:
        prompt_dim: channels of each prompt component (and of the output).
        prompt_len: number of prompt components mixed together.
        prompt_size: spatial size of each stored component before resizing.
        lin_dim: channels of the input feature ``x`` (input width of the linear layer).
    """
    def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlock,self).__init__()
        # Prompt bank of shape (1, prompt_len, prompt_dim, prompt_size, prompt_size).
        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        self.linear_layer = nn.Linear(lin_dim,prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)

    def forward(self,x):
        """Return a (B, prompt_dim, H, W) prompt for a (B, lin_dim, H, W) input ``x``."""
        _, _, H, W = x.shape
        emb = x.mean(dim=(-2, -1))                                    # (B, lin_dim)
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)     # (B, prompt_len)
        # Weighted sum over the prompt components. Broadcasting (B, L, 1, 1, 1)
        # weights against the (1, L, D, S, S) bank avoids materializing a
        # per-sample copy of the bank.
        prompt = (prompt_weights[:, :, None, None, None] * self.prompt_param).sum(dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        return self.conv3x3(prompt)





##########################################################################
##---------- PromptIR -----------------------

class PromptIR(nn.Module):
    """PromptIR restoration network: a Restormer-style U-Net with optional decoder prompts.

    Encoder: patch embedding, three encoder levels and a latent stage. Each
    Downsample halves H and W and doubles the channels (dim -> 2*dim -> 4*dim
    -> 8*dim). Decoder: Upsample + concatenation with the matching encoder
    feature back to full resolution, then refinement blocks, a 3x3 output conv,
    and a global residual (output + inp_img).

    When ``decoder`` is True, before each of the three decoder upsamplings a
    PromptGenBlock prompt is concatenated to the features, mixed by a
    TransformerBlock (``noise_level*``), and reduced by a 1x1 conv
    (``reduce_noise_level*``).

    NOTE(review): the decoder-side channel counts are hard-coded and only agree
    with each other for dim=48 (e.g. noise_level3 has dim*4 + 512 channels while
    the concatenated input has dim*8 + 320). With ``decoder=False`` the forward
    pass cannot run: ``latent`` has dim*8 channels but ``up4_3`` is built for
    dim*4 inputs. Only ``decoder=True`` with dim=48 appears to be supported;
    confirm before using other settings.

    NOTE(review): ``chnl_reduce1-3`` and ``reduce_noise_channel_1-3`` are never
    used in ``forward``, and ``noise_emb`` is ignored. Presumably they are kept
    so existing checkpoints still load.
    """
    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        decoder = False,
    ):
        super(PromptIR, self).__init__()
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.decoder = decoder
        if self.decoder:
            # lin_dim matches the channels of the feature each prompt is generated from
            # (96 = out_dec_level2, 192 = out_dec_level3, 384 = latent when dim=48).
            self.prompt1 = PromptGenBlock(prompt_dim=64,prompt_len=5,prompt_size = 64,lin_dim = 96)
            self.prompt2 = PromptGenBlock(prompt_dim=128,prompt_len=5,prompt_size = 32,lin_dim = 192)
            self.prompt3 = PromptGenBlock(prompt_dim=320,prompt_len=5,prompt_size = 16,lin_dim = 384)
        # Not used by forward (see class NOTE).
        self.chnl_reduce1 = nn.Conv2d(64,64,kernel_size=1,bias=bias)
        self.chnl_reduce2 = nn.Conv2d(128,128,kernel_size=1,bias=bias)
        self.chnl_reduce3 = nn.Conv2d(320,256,kernel_size=1,bias=bias)

        self.reduce_noise_channel_1 = nn.Conv2d(dim + 64,dim,kernel_size=1,bias=bias)
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2

        self.reduce_noise_channel_2 = nn.Conv2d(int(dim*2**1) + 128,int(dim*2**1),kernel_size=1,bias=bias)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3

        self.reduce_noise_channel_3 = nn.Conv2d(int(dim*2**2) + 256,int(dim*2**2),kernel_size=1,bias=bias)
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])

        # Built for dim*4 input channels, i.e. the output of reduce_noise_level3 (decoder=True only).
        self.up4_3 = Upsample(int(dim*2**2)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**1)+192, int(dim*2**2), kernel_size=1, bias=bias)
        self.noise_level3 = TransformerBlock(dim=int(dim*2**2) + 512, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level3 = nn.Conv2d(int(dim*2**2)+512,int(dim*2**2),kernel_size=1,bias=bias)


        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])


        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.noise_level2 = TransformerBlock(dim=int(dim*2**1) + 224, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level2 = nn.Conv2d(int(dim*2**1)+224,int(dim*2**2),kernel_size=1,bias=bias)

        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.noise_level1 = TransformerBlock(dim=int(dim*2**1)+64, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level1 = nn.Conv2d(int(dim*2**1)+64,int(dim*2**1),kernel_size=1,bias=bias)

        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img,noise_emb = None):
        # noise_emb is accepted for interface compatibility but is not used.
        # Encoder: each Downsample halves H, W and doubles the channels.
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)
        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)
        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)
        if self.decoder:
            # Append a generated prompt, mix with a TransformerBlock, reduce back to dim*4 channels.
            dec3_param = self.prompt3(latent)
            latent = torch.cat([latent, dec3_param], 1)
            latent = self.noise_level3(latent)
            latent = self.reduce_noise_level3(latent)

        inp_dec_level3 = self.up4_3(latent)

        # Skip connection from the encoder at the same resolution.
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)

        out_dec_level3 = self.decoder_level3(inp_dec_level3)
        if self.decoder:
            dec2_param = self.prompt2(out_dec_level3)
            out_dec_level3 = torch.cat([out_dec_level3, dec2_param], 1)
            out_dec_level3 = self.noise_level2(out_dec_level3)
            out_dec_level3 = self.reduce_noise_level2(out_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)

        out_dec_level2 = self.decoder_level2(inp_dec_level2)
        if self.decoder:
            dec1_param = self.prompt1(out_dec_level2)
            out_dec_level2 = torch.cat([out_dec_level2, dec1_param], 1)
            out_dec_level2 = self.noise_level1(out_dec_level2)
            out_dec_level2 = self.reduce_noise_level1(out_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        # No channel reduction at level 1: the decoder runs at dim*2 channels.
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        out_dec_level1 = self.refinement(out_dec_level1)
        # Global residual: the network predicts a correction added to the input image.
        out_dec_level1 = self.output(out_dec_level1) + inp_img
        return out_dec_level1


class ch_shuffle_high_text(nn.Module):
    """Text-guided channel shuffle followed by masked cross-attention and an FFN.

    The text embedding is mapped to 2*ch_dim channel scores. The image feature
    is expanded to 2*ch_dim channels by a 1x1 conv, its channels are reordered
    by descending score (torch.topk over all 2*ch_dim channels), and it is
    projected back to ch_dim. That shuffled feature is the query of a
    Topm_CrossAttention_Restormer whose key/value come from the unshuffled
    input and whose prompt is the raw text embedding. The attended feature then
    passes through a residual FeedForward.

    Args:
        ch_dim: channels of the image feature.
        num_heads: heads of the cross-attention (ch_dim must be divisible by it).
        LayerNorm_type: passed to the three LayerNorms.
        ffn_expansion_factor, bias: passed to FeedForward.
        lin_ch: width of the text embedding; also used as the cross-attention's
            prompt width, because the raw embedding is forwarded to it.

    NOTE(review): only the integer indices from topk are used, so
    linear_layer1/linear_layer3 receive no gradient through this module. Confirm
    this is intended.
    """
    def __init__(self, ch_dim,num_heads,LayerNorm_type,ffn_expansion_factor, bias,lin_ch=512):
        super(ch_shuffle_high_text, self).__init__()
        self.dim = ch_dim
        # Text branch: one score per expanded channel.
        self.linear_layer1 = nn.Linear(lin_ch,lin_ch)
        self.linear_layer3 = nn.Linear(lin_ch,2*ch_dim)
        # Image branch: expand to 2*ch_dim channels, then project the shuffled result back to ch_dim.
        self.conv1x1   = nn.Conv2d(ch_dim, 2*ch_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out   = nn.Conv2d(2*ch_dim, ch_dim, kernel_size=1, stride=1, padding=0)
        self.norm1 = LayerNorm(ch_dim, LayerNorm_type)
        self.norm2 = LayerNorm(ch_dim, LayerNorm_type)
        self.norm3 = LayerNorm(ch_dim, LayerNorm_type)
        # The raw text_code (width lin_ch) is the attention's degradation prompt,
        # so its prompt width must be lin_ch (it was previously fixed at 512).
        self.select_attn = Topm_CrossAttention_Restormer(ch_dim, num_heads, prompt_embedding_dim=lin_ch, bias=False)
        self.ffn = FeedForward(ch_dim, ffn_expansion_factor, bias)

    def forward(self, img_featur, text_code):
        """Return ``(output, img_featur)``: the attended feature and the unchanged input feature."""
        b, _, _, _ = img_featur.shape
        identity = img_featur
        deg_prompt = text_code
        scores = self.linear_layer3(self.linear_layer1(text_code))
        # Only the ordering is used; topk indices are integers and carry no gradient.
        _, order = torch.topk(scores, k=2*self.dim)
        expanded = self.conv1x1(img_featur)
        # Per-sample channel gather; the batch index lives on the feature's device.
        batch_idx = torch.arange(b, device=expanded.device).unsqueeze(1)
        shuffled_img = expanded[batch_idx, order, :, :]

        q = self.conv_out(shuffled_img)
        att = self.select_attn(self.norm1(q), self.norm2(identity), deg_prompt)
        output = att + self.ffn(self.norm3(att))
        return output, identity
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdvancedLearnableMaskGenerator(nn.Module):
    """Predict a soft (B, C, C) channel-affinity mask from two feature maps and a prompt.

    Both feature maps are global-average-pooled to (B, C) vectors, concatenated
    and projected to ``hidden_dim``. The prompt is projected to ``hidden_dim``
    separately. An MLP on the joined (B, 2*hidden_dim) vector emits C*C logits,
    which are reshaped to (B, C, C) and passed through a sigmoid into (0, 1).

    Args:
        channel_dim: C, the channel count of each input feature map.
        prompt_embedding_dim: width of ``degradation_prompt``.
        hidden_dim: width of the internal projections.
    """
    def __init__(self, channel_dim, prompt_embedding_dim=512, hidden_dim=128):
        super().__init__()
        self.channel_dim = channel_dim
        # Pooled original + shuffled descriptors -> hidden.
        self.feature_processor = nn.Linear(2 * channel_dim, hidden_dim)
        # Degradation prompt -> hidden.
        self.prompt_processor = nn.Linear(prompt_embedding_dim, hidden_dim)
        # Joined descriptors -> C*C mask logits.
        self.mask_generator = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, channel_dim * channel_dim),
        )

    def forward(self, original_feature, shuffled_feature, degradation_prompt):
        """
        Args:
            original_feature: (B, C, H, W) feature map.
            shuffled_feature: (B, C, H', W') feature map.
            degradation_prompt: (B, prompt_embedding_dim) embedding.

        Returns:
            (B, C, C) soft mask with values in (0, 1).
        """
        def pooled(feature):
            # Global average pool: (B, C, H, W) -> (B, C).
            return F.adaptive_avg_pool2d(feature, 1).squeeze(-1).squeeze(-1)

        joined_features = torch.cat([pooled(original_feature), pooled(shuffled_feature)], dim=1)
        feature_code = self.feature_processor(joined_features)
        prompt_code = self.prompt_processor(degradation_prompt)
        logits = self.mask_generator(torch.cat([feature_code, prompt_code], dim=1))
        return torch.sigmoid(logits.view(-1, self.channel_dim, self.channel_dim))

class Topm_CrossAttention_Restormer(nn.Module):
    """Channel-wise (transposed) cross-attention with a learnable, prompt-conditioned mask.

    The query comes from ``shuffled_feature``; key and value come from
    ``original_feature``. For each head, the (head_dim x head_dim) affinity of
    L2-normalized query/key channel vectors is scaled by a learnable per-head
    temperature. It is then multiplied element-wise by a soft mask from an
    AdvancedLearnableMaskGenerator shared by all heads, softmaxed, and applied
    to the values.

    Args:
        dim: channels of both feature maps; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        prompt_embedding_dim: width of ``degradation_prompt``.
        bias: whether the 1x1 projections use a bias.
    """
    def __init__(self, dim, num_heads, prompt_embedding_dim=512, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # Projections for query, key, value and output.
        self.query_proj = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.key_proj = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.value_proj = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.output_proj = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        # One mask generator shared by all heads, operating on head_dim channels.
        self.mask_generator = AdvancedLearnableMaskGenerator(
            channel_dim=dim // num_heads,
            prompt_embedding_dim=prompt_embedding_dim
        )

    def forward(self, shuffled_feature, original_feature, degradation_prompt):
        """
        Args:
            shuffled_feature: (B, C, H, W) query source.
            original_feature: (B, C, H, W) key/value source.
            degradation_prompt: (B, prompt_embedding_dim) embedding; a single
                (1, prompt_embedding_dim) prompt is broadcast over the batch.

        Returns:
            (B, C, H, W) attended feature.
        """
        b, c, h, w = original_feature.shape
        heads = self.num_heads
        head_dim = c // heads

        # reshape rather than view: callers pass LayerNorm/einops outputs,
        # which are not guaranteed to be contiguous.
        query = self.query_proj(shuffled_feature).reshape(b, heads, head_dim, h * w)
        key = self.key_proj(original_feature).reshape(b, heads, head_dim, h * w)
        value = self.value_proj(original_feature).reshape(b, heads, head_dim, h * w)

        query = F.normalize(query, dim=-1)
        key = F.normalize(key, dim=-1)
        raw_attn_scores = (query @ key.transpose(-2, -1)) * self.temperature

        # Evaluate the shared mask generator for every head in one batched call.
        # Fold heads into the batch axis (batch-major, head-minor), repeat each
        # sample's prompt once per head, then unfold to (B, heads, hd, hd).
        per_head_original = original_feature.reshape(b * heads, head_dim, h, w)
        per_head_shuffled = shuffled_feature.reshape(b * heads, head_dim, h, w)
        per_head_prompt = degradation_prompt.expand(b, -1).repeat_interleave(heads, dim=0)
        learnable_mask = self.mask_generator(
            per_head_original, per_head_shuffled, per_head_prompt
        ).view(b, heads, head_dim, head_dim)

        # The mask scales the raw scores before the softmax.
        attention_probs = F.softmax(raw_attn_scores * learnable_mask, dim=-1)

        output = (attention_probs @ value).reshape(b, c, h, w)
        return self.output_proj(output)

class ChannelShuffle_skip_textguaid(nn.Module):
    """Restormer-style U-Net whose latent and skip features are text-guided.

    The encoder/decoder layout matches Restormer (patch embedding, three
    encoder levels, latent, three decoder levels, refinement, output conv with
    a global residual). The latent and each encoder skip feature are passed
    through a ch_shuffle_high_text module (text-guided channel shuffle + masked
    cross-attention) before the decoder uses them.

    NOTE(review): ``device`` is only stored on ``self`` (its one use in forward
    is commented out), and ``dual_pixel_task`` is accepted but unused because
    the dual-pixel branch is commented out.
    """
    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        device = "cuda:1",
        # decoder = False,
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):
        super(ChannelShuffle_skip_textguaid, self).__init__()
        self.device = device

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.encoder_shuffle_channel1 = ch_shuffle_high_text(ch_dim = dim,num_heads=heads[0],LayerNorm_type=LayerNorm_type,ffn_expansion_factor=ffn_expansion_factor,bias=bias) # encoder level1 shuffle

        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.encoder_shuffle_channel2 = ch_shuffle_high_text(ch_dim = int(dim*2**1),num_heads=heads[1],LayerNorm_type=LayerNorm_type,ffn_expansion_factor=ffn_expansion_factor,bias=bias) # encoder level2 shuffle

        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
        self.encoder_shuffle_channel3 = ch_shuffle_high_text(ch_dim = int(dim*2**2),num_heads=heads[2],LayerNorm_type=LayerNorm_type,ffn_expansion_factor=ffn_expansion_factor,bias=bias) # encoder level3 shuffle

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        self.latent_shuffle_channel = ch_shuffle_high_text(ch_dim = int(dim*2**3),num_heads=heads[3],LayerNorm_type=LayerNorm_type,ffn_expansion_factor=ffn_expansion_factor,bias=bias) # latent shuffle

        self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)
        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
        #### For Dual-Pixel Defocus Deblurring Task ####
        # self.dual_pixel_task = dual_pixel_task
        # if self.dual_pixel_task:
        #     self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        # ###########################

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img ,text_code):
        """Restore ``inp_img`` guided by ``text_code``.

        Args:
            inp_img: input image batch; the output adds a predicted residual to it.
            text_code: text embedding passed to every ch_shuffle_high_text module.
                Presumably a (B, 512) CLIP text embedding, given that module's
                default lin_ch=512 (TODO: confirm against the caller).
        """
        # text_code = torch.randn(1,512).to(self.device) # enable this line when measuring the model's parameter count and FLOPs
        inp_enc_level1 = self.patch_embed(inp_img) # ch 3-->dim:48
        out_enc_level1 = self.encoder_level1(inp_enc_level1) # ch dim:48-->dim:48

        inp_enc_level2 = self.down1_2(out_enc_level1) # ch dim:48-->dim*2:96
        out_enc_level2 = self.encoder_level2(inp_enc_level2) # ch dim*2:96-->dim*2:96

        inp_enc_level3 = self.down2_3(out_enc_level2) # ch dim*2:96-->dim*2*2:192
        out_enc_level3 = self.encoder_level3(inp_enc_level3) # ch dim*2*2:192-->dim*2*2:192

        inp_enc_level4 = self.down3_4(out_enc_level3) # ch 192-->dim*8:384
        latent = self.latent(inp_enc_level4)
        latent,_ = self.latent_shuffle_channel(latent,text_code) # latent shuffle

        # Each decoder level concatenates its upsampled input with the text-guided encoder skip feature.
        inp_dec_level3 = self.up4_3(latent)
        outt1,_ = self.encoder_shuffle_channel3(out_enc_level3,text_code)
        inp_dec_level3 = torch.cat([inp_dec_level3, outt1], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        outt2,_ = self.encoder_shuffle_channel2(out_enc_level2,text_code)
        inp_dec_level2 = torch.cat([inp_dec_level2, outt2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        outt3,_ = self.encoder_shuffle_channel1(out_enc_level1,text_code)
        inp_dec_level1 = torch.cat([inp_dec_level1, outt3], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        #### For Dual-Pixel Defocus Deblurring Task ####
        # if self.dual_pixel_task:
        #     out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
        #     out_dec_level1 = self.output(out_dec_level1)
        # ###########################
        # else:
        # Global residual: the network predicts a correction added to the input image.
        out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1



##---------- Restormer -----------------------
class Restormer(nn.Module):
    """Baseline Restormer: a 4-level transformer encoder/decoder with skip connections.

    Each encoder level is a stack of ``TransformerBlock``s followed by a
    ``Downsample``. The decoder mirrors the encoder: it upsamples, concatenates
    the matching encoder output, and (at levels 3 and 2) fuses the result with
    a 1x1 conv before its own transformer stack. Level 1 of the decoder keeps
    the concatenated width (2*dim) and is followed by a refinement stack.

    Args:
        inp_channels: channels of the input image.
        out_channels: channels produced by the final 3x3 output conv.
        dim: base channel width of level 1; level k uses ``int(dim * 2**(k-1))``.
        num_blocks: transformer-block counts for levels 1..4 (the decoder
            reuses the counts of levels 1..3).
        num_refinement_blocks: blocks in the refinement stack after decoder level 1.
        heads: attention head counts for levels 1..4.
        ffn_expansion_factor, bias, LayerNorm_type: forwarded to every
            ``TransformerBlock`` (``LayerNorm_type`` may also be 'BiasFree').
        dual_pixel_task: True only for dual-pixel defocus deblurring (also set
            ``inp_channels=6``); replaces the global image residual with a
            1x1-conv skip from the patch embedding.

    NOTE(review): assumes NCHW input whose H and W survive three
    Downsample/Upsample round trips (presumably divisible by 8) — those
    modules are defined elsewhere; confirm.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):
        super().__init__()

        def stack(width, n_heads, depth):
            # `depth` TransformerBlocks that all operate at `width` channels.
            return nn.Sequential(*[
                TransformerBlock(dim=width, num_heads=n_heads,
                                 ffn_expansion_factor=ffn_expansion_factor,
                                 bias=bias, LayerNorm_type=LayerNorm_type)
                for _ in range(depth)
            ])

        # Channel width of each pyramid level (level 1 keeps `dim` as given).
        w1 = dim
        w2 = int(dim * 2 ** 1)
        w3 = int(dim * 2 ** 2)
        w4 = int(dim * 2 ** 3)

        # Submodules are created in this exact order so parameter
        # initialisation (and state_dict key names) match the reference model.
        self.patch_embed = OverlapPatchEmbed(inp_channels, w1)

        # Encoder, level 1 -> level 4 (latent).
        self.encoder_level1 = stack(w1, heads[0], num_blocks[0])
        self.down1_2 = Downsample(w1)
        self.encoder_level2 = stack(w2, heads[1], num_blocks[1])
        self.down2_3 = Downsample(w2)
        self.encoder_level3 = stack(w3, heads[2], num_blocks[2])
        self.down3_4 = Downsample(w3)
        self.latent = stack(w4, heads[3], num_blocks[3])

        # Decoder, level 4 -> level 1. After concatenating the skip, the 1x1
        # convs bring the channel count back to the level's width.
        self.up4_3 = Upsample(w4)
        self.reduce_chan_level3 = nn.Conv2d(w4, w3, kernel_size=1, bias=bias)
        self.decoder_level3 = stack(w3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(w3)
        self.reduce_chan_level2 = nn.Conv2d(w3, w2, kernel_size=1, bias=bias)
        self.decoder_level2 = stack(w2, heads[1], num_blocks[1])

        # Level 1 has no 1x1 reduction: it runs at the concatenated width w2.
        self.up2_1 = Upsample(w2)
        self.decoder_level1 = stack(w2, heads[0], num_blocks[0])

        self.refinement = stack(w2, heads[0], num_refinement_blocks)

        # Dual-pixel defocus deblurring: extra skip from the patch embedding.
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(w1, w2, kernel_size=1, bias=bias)

        self.output = nn.Conv2d(w2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        """Restore ``inp_img``; returns a tensor with ``out_channels`` channels.

        In the default mode the network predicts a residual that is added to
        ``inp_img``; in dual-pixel mode the output conv result is returned as is.
        """
        # Encoder: keep every level's output for the decoder skip connections.
        embedded = self.patch_embed(inp_img)
        skip1 = self.encoder_level1(embedded)
        skip2 = self.encoder_level2(self.down1_2(skip1))
        skip3 = self.encoder_level3(self.down2_3(skip2))
        bottleneck = self.latent(self.down3_4(skip3))

        # Decoder: upsample, concatenate the skip along channels, fuse, decode.
        feats = torch.cat([self.up4_3(bottleneck), skip3], 1)
        feats = self.decoder_level3(self.reduce_chan_level3(feats))

        feats = torch.cat([self.up3_2(feats), skip2], 1)
        feats = self.decoder_level2(self.reduce_chan_level2(feats))

        feats = torch.cat([self.up2_1(feats), skip1], 1)
        feats = self.refinement(self.decoder_level1(feats))

        if not self.dual_pixel_task:
            # Global residual: the network predicts a correction to the input.
            return self.output(feats) + inp_img

        # Dual-pixel mode: skip from the patch embedding, no image residual.
        return self.output(feats + self.skip_conv(embedded))




Overwriting net/model.py


# Train and test

In [None]:
# Launch training by running train_3D_DFPIR.py from the current working
# directory (the DFPIR-main folder entered with `cd` at the top of the notebook).
# NOTE(review): both --gpu "0" and --cuda 0 are passed; confirm which one the
# script actually uses to select the device.
# NOTE(review): --epochs 1 with --batch_size 2 looks like a smoke-test setting;
# the meaning of --save_item 20000 is not visible here, so check the script's
# argument definitions.
# All paths are relative to the working directory. Presumably training data lives
# under ./train_data/{rain,haze,noise}/, evaluation data under
# ./test/{derain,dehaze,denoise}/, and checkpoints are written to ./checkpoints/.
# Confirm these against the script.
!python train_3D_DFPIR.py \
    --epochs 1 \
    --gpu "0" \
    --cuda 0 \
    --batch_size 2 \
    --save_item 20000\
    --save_dir "./checkpoints/" \
    --derain_dir "./train_data/rain/" \
    --dehaze_dir "./train_data/haze/" \
    --denoise_dir "./train_data/noise/" \
    --derain_path "./test/derain/" \
    --dehaze_path "./test/dehaze/" \
    --denoise_path "./test/denoise/"

2025-09-30 20:30:23.607881: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759264223.630191    4893 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759264223.637001    4893 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1759264223.654394    4893 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1759264223.654421    4893 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1759264223.654426    4893 computation_placer.cc:177] computation placer alr

In [None]:
# Evaluate a pretrained checkpoint by running test_3D_DFPIR.py from the current
# working directory.
# NOTE(review): the .pth.tar filename appears to encode the checkpoint's
# per-task PSNR/SSIM scores. The file must already exist in the working directory.
# The test-set folders are the same ones passed as --*_path in the training cell.
# The missing "./" prefix here is equivalent.
# --output_path is presumably where restored images are written; confirm this
# in the script.
!python test_3D_DFPIR.py \
    --gpu "0" \
    --cuda 0 \
    --pretrained_1 "./DFPIR-3D_p_n34.14-0.9354_31.47-0.8928_28.25-0.8059_p_r38.65-0.9821_p_h31.87-0.9800avr32.88-0.9192.pth.tar" \
    --denoise_path "test/denoise/" \
    --derain_path "test/derain/" \
    --dehaze_path "test/dehaze/" \
    --output_path "output/"

=> loading model './DFPIR-3D_p_n34.14-0.9354_31.47-0.8928_28.25-0.8059_p_r38.65-0.9821_p_h31.87-0.9800avr32.88-0.9192.pth.tar'
Strart test 3D
Start bsd68/ testing Sigma=15...
100%|[35m███████████████████████████████████████████[0m| 68/68 [00:39<00:00,  1.71it/s][0m
Denoise sigma=15: psnr: 34.15, ssim: 0.9357
Average Inference Time: 39795.82 ms
bsd68/test ok psnr_g15:34.1456 ssim_g15:0.9357,
Start bsd68/ testing Sigma=25...
 13%|[35m█████▊                                      [0m| 9/68 [00:05<00:34,  1.69it/s][0m^C
 13%|[35m█████▊                                      [0m| 9/68 [00:05<00:38,  1.53it/s][0m
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/PIL/ImageFile.py", line 643, in _save
    fh = fp.fileno()
         ^^^^^^^^^
AttributeError: '_idat' object has no attribute 'fileno'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/kaggle/working/DFPIR-main/test_3D_DFPIR.py", lin

'/kaggle/working/DFPIR-main'