In [1]:
!pip install transformers torch



In [2]:
# Get the Hugging Face access token from Colab's secret store.
# Open the key icon ("Secrets") in the left sidebar, add a secret named
# HF_TOKEN holding your Hugging Face access token, and grant this notebook access.
from google.colab import userdata
import os

HF_TOKEN = userdata.get('HF_TOKEN')
# huggingface_hub / transformers pick the token up from the HF_TOKEN environment
# variable; HUGGINGFACE_TOKEN is not one of the variables they document, so the
# original assignment alone did not authenticate anything. It is kept so any
# other code that reads it still works.
os.environ['HF_TOKEN'] = HF_TOKEN
os.environ['HUGGINGFACE_TOKEN'] = HF_TOKEN

In [3]:
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl.metadata (2.9 kB)
Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl (69.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.0


In [4]:
## How to use it:
# Load BSAtlas/BSJCode-1-Stable in 4-bit and its tokenizer.

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

MODEL_ID = "BSAtlas/BSJCode-1-Stable"

# 4-bit NF4 weights with fp16 compute; double quantization also quantizes the
# quantization constants. fp32 CPU offload allows layers that do not fit on the
# GPU to be placed on the CPU instead.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# device_map="auto" already places every layer. Do NOT chain .to("cuda") here:
# it overrides accelerate's placement (pulling any CPU-offloaded layers onto the
# GPU, defeating the offload flag above), and older bitsandbytes releases
# reject .to() on quantized models outright.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def detect_and_fix_bugs(code_snippet, *, max_new_tokens=1024, max_input_tokens=1024,
                        temperature=1.0, top_k=50, top_p=0.95,
                        repetition_penalty=1.2, lm=None, tok=None):
    """Ask the code LLM to find bugs/inefficiencies in a Java snippet and return its fix.

    Args:
        code_snippet: Java source text to analyse; embedded verbatim in the prompt.
        max_new_tokens: Cap on *generated* tokens. Unlike ``max_length`` (used
            previously), this does not count the prompt, so a long snippet no
            longer leaves the model with no room to answer.
        max_input_tokens: Maximum prompt length in tokens; longer prompts are
            truncated by the tokenizer and a warning is emitted.
        temperature, top_k, top_p, repetition_penalty: Sampling settings
            forwarded to ``generate``.
        lm, tok: Model and tokenizer to use. They default to the notebook-level
            ``model`` and ``tokenizer`` globals.

    Returns:
        The model's suggested code (the generated continuation only), with
        surrounding whitespace stripped. Sampling is enabled, so results vary
        between calls.
    """
    import warnings

    marker = "Optimized and Fixed Code:"
    # Only touch the notebook globals when no override was passed.
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok

    # The original f-string never interpolated code_snippet, so the model was
    # asked to fix code it was never shown.
    prompt = (
        "You are an expert Java code optimizer and bug fixer.\n"
        "Analyze the following code, identify any bugs or inefficiencies,\n"
        "and provide an optimized and corrected version:\n\n"
        f"{code_snippet.strip()}\n\n"
        f"{marker}"
    )

    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens)
    # Move input tensors to the same device as the model.
    inputs = {name: tensor.to(lm.device) for name, tensor in inputs.items()}

    prompt_len = inputs["input_ids"].shape[-1]
    if prompt_len >= max_input_tokens:
        warnings.warn(
            f"Prompt reached max_input_tokens={max_input_tokens}; the code snippet "
            "and trailing instruction may have been truncated.",
            stacklevel=2,
        )

    # Many causal-LM tokenizers define no pad token; fall back to EOS so
    # generate() does not have to guess (and warn).
    pad_id = getattr(tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(tok, "eos_token_id", None)

    with torch.no_grad():
        output_ids = lm.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=pad_id,
        )

    # Causal LMs return prompt + continuation; decode only the new tokens so the
    # prompt (which itself ends with the marker) never leaks into the result.
    completion = tok.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    # If the model echoes the marker itself, keep only what follows the last one.
    _, found, tail = completion.rpartition(marker)
    if found:
        completion = tail
    return completion.strip()


# Deliberately buggy / inefficient Java used as a demo input for the fixer:
# - an O(n^2) nested loop whose inner index j is never used;
# - minVal is computed but never returned;
# - a loop that only builds, sorts and prints throwaway two-element sublists;
# - Integer.MIN_VALUE returned for a null/empty list;
# - a return that may call findMin(nums) again with the same argument.
# NOTE(review): the MathUtils.minOf/maxOf calls are part of the "buggy" sample;
# they were not verified to exist in commons-math3.
sample_code = """
import java.util.*;
import org.apache.commons.math3.util.MathUtils;


public class Solution {

    // Complete the findMin function below.
    /**
     * @return the minimum value of all values in a list.
     */
    private int findMin(List<Integer> nums) {
        if (nums == null || nums.isEmpty()) {
            return Integer.MIN_VALUE;
        }

        // Use a nested loop to compare elements inefficiently
        int minVal = Integer.MAX_VALUE;
        for (int i = 0; i < nums.size(); i++) {
            for (int j = 0; j < nums.size(); j++) {
                if (nums.get(i) < minVal) {
                    minVal = nums.get(i);
                }
            }
        }

        // Create unnecessary temporary lists and perform redundant comparisons
        for (int i = 0; i < nums.size() - 1; i++) {
            if ((i + 2) >= nums.size()) {
                break;
            } else {
                List<Integer> sublist = new ArrayList<>();
                sublist.addAll(Arrays.asList(new Integer[]{nums.get(i), nums.get(i + 1)}));
                sublist.sort(Comparator.reverseOrder());
                System.out.println("sublist size:" + sublist.size());
                System.out.println(sublist.toString());
            }
        }

        //Unnecessary recursive call and comparison
        return MathUtils.minOf(nums).compareTo(MathUtils.maxOf(nums)) == -1 ? findMin(nums) : MathUtils.maxOf(nums) + 1;
    }
}
"""
# Run the fixer once and print its answer. Generation uses sampling
# (do_sample=True), so the printed code differs between runs.
fixed_code = detect_and_fix_bugs(sample_code)
print(fixed_code)

adapter_config.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/948 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

public class Solution {
        // Complete the findMin function below.
         /**
             * @return the minimum value of all values in a list.
             */
            private int findMin(List<Integer> nums) {
                if (nums == null || nums.isEmpty()) return Integer.MIN_VALUE;
                      for (int i =0 ;i < nums.size()-1; i++) {
                          if((i+2)>=nums.size()){break;}
                          else{
                              List<Integer> sublist = new ArrayList<>();
                               sublist.addAll(Arrays.asList(new Integer[]{nums.get(i), nums.get(i + 1)}));
                              sublist.sort(Comparator.reverseOrder());
                              System.out.println("sublist size:"+sublist.size());
                              System.out.println(sublist.toString());
                           }
                      }
                return MathUtils.minOf(nums).compareTo(MathUtils.maxOf(nums))== -1?findMin(nu