In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Runtime configuration for the notebook.
# Prefer Apple's Metal backend (the original hard-coded choice), but fall back
# to CUDA and then CPU so the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub id of the seq2seq checkpoint used by the cells that follow.
model_name = "celinelee/bartlarge_risctoarm_cloze2048"

In [8]:
# Tokenizer: loaded from the base BART checkpoint, not from `model_name`.
# NOTE(review): presumably the fine-tuned checkpoint shares this vocabulary — confirm.
bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

# Pad with the end-of-sequence token, and pad on the right-hand side.
# NOTE(review): this replaces whatever pad token the tokenizer shipped with — confirm intended.
bart_tokenizer.pad_token = bart_tokenizer.eos_token
bart_tokenizer.padding_side = "right"

# Model: weights are loaded in bfloat16; device placement is delegated to
# `device_map="auto"` rather than chosen explicitly here.
bart_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)

In [40]:
# Language labels interpolated into the instruction prompt below.
source_lang, target_lang = "RISC-V", "ARM64"

# Assembly listing to translate, kept as one escaped string literal
# (tabs and newlines are significant and must not be reformatted).
source_code = "\t.file\t\"cat.c\"\n\t.option pic\n\t.text\n\t.align\t1\n\t.globl\tmain\n\t.type\tmain, @function\nmain:\n\taddi\tsp,sp,-160\n\tsd\tra,152(sp)\n\tsd\ts0,144(sp)\n\taddi\ts0,sp,160\n\tmv\ta5,a0\n\tsd\ta1,-160(s0)\n\tsw\ta5,-148(s0)\n\tla\ta5,__stack_chk_guard\n\tld\ta4, 0(a5)\n\tsd\ta4, -24(s0)\n\tli\ta4, 0\n\tld\ta5,-160(s0)\n\taddi\ta5,a5,8\n\tld\ta5,0(a5)\n\tli\ta1,0\n\tmv\ta0,a5\n\tcall\topen@plt\n\tmv\ta5,a0\n\tsw\ta5,-136(s0)\n\tj\t.L2\n.L3:\n\tlw\ta4,-132(s0)\n\taddi\ta5,s0,-128\n\tmv\ta2,a4\n\tmv\ta1,a5\n\tli\ta0,1\n\tcall\twrite@plt\n.L2:\n\taddi\ta4,s0,-128\n\tlw\ta5,-136(s0)\n\tli\ta2,99\n\tmv\ta1,a4\n\tmv\ta0,a5\n\tcall\tread@plt\n\tmv\ta5,a0\n\tsw\ta5,-132(s0)\n\tlw\ta5,-132(s0)\n\tsext.w\ta5,a5\n\tbne\ta5,zero,.L3\n\tli\ta0,10\n\tcall\tputchar@plt\n\tlw\ta5,-136(s0)\n\tmv\ta0,a5\n\tcall\tclose@plt\n\tli\ta5,0\n\tmv\ta4,a5\n\tla\ta5,__stack_chk_guard\n\tld\ta3, -24(s0)\n\tld\ta5, 0(a5)\n\txor\ta5, a3, a5\n\tli\ta3, 0\n\tbeq\ta5,zero,.L5\n\tcall\t__stack_chk_fail@plt\n.L5:\n\tmv\ta0,a4\n\tld\tra,152(sp)\n\tld\ts0,144(sp)\n\taddi\tsp,sp,160\n\tjr\tra\n\t.size\tmain, .-main\n\t.ident\t\"GCC: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0\"\n\t.section\t.note.GNU-stack,\"\",@progbits\n"

# Instruction-style prompt that wraps the listing in a fenced code block.
# NOTE(review): the tokenizer call that follows encodes `source_code`, not
# `text` — confirm whether this prompt is meant to be used.
text = (
    f"Convert the following {source_lang} assembly code to {target_lang} assembly:\n"
    f"```{source_lang.lower()}asm\n"
    f"{source_code}```"
)

# Encode the raw assembly listing as the encoder input.
# - max_length is tied to the loaded model's positional capacity instead of the
#   tokenizer's own default, because the tokenizer comes from a different
#   checkpoint than the model and the two limits need not agree.
# - The tensors are moved to the device the model actually lives on; the model
#   is placed by `device_map="auto"`, so a hard-coded device string can disagree
#   with it and make generate() fail with a device mismatch.
tokens = bart_tokenizer(
    source_code,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=bart_model.config.max_position_embeddings,
).to(bart_model.device)

In [45]:
# Generation options: return a structured output object that carries
# attentions and per-step scores, and keep 5 candidates from a 5-beam search.
kwargs = {
    "output_attentions": True,
    "return_dict_in_generate": True,
    "output_scores": True,
    "num_beams": 5,
    "num_return_sequences": 5
}

# Sampled beam search (do_sample=True with num_beams > 1) at temperature 0.7.
# max_new_tokens is clamped to the model's positional capacity: the decoder
# sequence is one start token plus the newly generated tokens, and running past
# `max_position_embeddings` would index outside the position-embedding table.
# NOTE(review): no RNG seed is set in the visible cells, so reruns may differ.
output = bart_model.generate(
    **tokens,
    max_new_tokens=min(4096, bart_model.config.max_position_embeddings - 1),
    temperature=0.7,
    do_sample=True,
    **kwargs
)

  test_elements = torch.tensor(test_elements)


In [46]:
# Convert every returned candidate sequence back to text, dropping special tokens.
# The `kwargs` dict holds generate() options (beam count, score/attention
# output flags); those mean nothing to the tokenizer, so they are not forwarded.
decoded = bart_tokenizer.batch_decode(
    output.sequences,
    skip_special_tokens=True,
)

In [47]:
# Show the first decoded candidate; print() renders its tabs and newlines
# as an assembly listing instead of an escaped string repr.
print(decoded[0])

main:
	.zero	cfi_startproc
	stp	x29, x30, [sp, -160]!
	1, -144]!fi_def_cfa_offset 160
	str	x2, [x0]
	add	x0, sp, 32
	mov	x1, x8
	ldp	q0, q1, [w1]
#APP
// 20 "program193680.c" 1
	isr
// 0 "" 2
#NO_APP
	bl	open
	and	w0, w0, 255
	b	.L2
.L4:
.ldrsw	x3, [xtw	x4, w3
	sxtwxy	w	w1, w1
	lsl	x5, x1, 5
	lsr	x6, x5, 4
	orr	x7, x4, x7
	cmp	w2, 0
	csel	x8, x2, x6, ge
	sub	x20, x3, x0
	nop
	umov	w3, v0.b[2]

	fmov	d0, 1.0e+0
.8
.p2align 4,,11
	adrp2, :got:__isoc99_chk_guard
.x2fa, lsl #12
	scvtf	s0, s0.2
	fcmp	s3, #1.8b}, [x11]
.word	10
	ubfx	x11, x10, 0, 6
	adfm	PSTL1K, s3, -1.4
	fdiv	x9, x9, s1
.cmn	sp, #0.cset	w4, eq
	eor	x10, x11, s5
	{main}.L3K:
#intz	x12, .-main
	br	x17
	dsb	sy
	xword	0, x20
	prfm	PLDL3K
	movtzs	x19, x19
	csinc	x18, x18, mi
	udiv	x16, x17, x16
	msr, x22, s4
.global	__aarch64_cas4_acq_rel
.align	3
.type	a, %function

